Skip to content

Commit 2661f51

Browse files
Support copy-via-host in from_dlpack
For arr that supports DLPack, version (1, 0), or legacy, support ``` from_dlpack(arr, device=target_dev) ``` where target_dev is `(kDLCPU, 0)` for transfer to host, or a value recognized by device keywords in dpctl.tensor for other functions, or `(kDLOneAPI, dev_id)`. To support transfer via host, `arr` must support `__dlpack__(max_version=(1,0), dl_device=(1, 0))`. For array objects with legacy `__dlpack__` support only, supported inputs are those residing on kDLCPU device, or those from kDLOneAPI device only. --- This is a combination of 17 commits squashed into one: Combine two validation checks into one, improving coverage Only fall-back to __dlpack__() if requested device does not change Simplify branching, only fall-back to no-arg call to __dlpack__ is dl_device is None or same as reported for the input Changed from_dlpack to copy via host is needed This enables dpt.from_dlpack(numpy_array, device="opencl:cpu") Add a test to exercise copy via host Handle possibilities for TypeError and BufferError These may be hard to test Change exception raised by __dlpack__ if dl_device is unsupported It used to raise NotImplementedError, not raises BufferError Add case of dlpack test to expand coverage Removed comment, add NotImplementedError to the except clause To ensure same validation across branches, compute host_blob by roundtripping it through dlpack Test from_dlpack on numpy input with strides not multiple of elementsize Refined from_dlpack docstrings, reorged impl of from_dlpack Used try/except/else/finally to avoid raising an exception when another one is in flight (confusing UX). device keyword is only allowed to be (kDLCPU, 0) or (kDLOneAPI, num). Device keyword value is used to create output array, rather than device_id deduced from it. Adjusted test per change in implementation Expand applicability of fall-back behavior When `from_dlpack(arr, device=dev)` is called, for `arr` object that supports legacy DLPack interface (max_version, dl_device, copy are not supported), we now support arr being device on host, that is (kDLCPU, 0), and (kDLOneAPI, different_device_id). Support for this last case is being added in this commit, as per review comment. Add symmetric support for containers with legacy DLPack support For legacy containers, support device=(kDLCPU, 0) as well as oneAPI device. Add tests for importing generic legacy and generic modern containers Fix typos in comments Add test for legacy container holding numpy's array.
1 parent dd4c0c0 commit 2661f51

File tree

3 files changed

+281
-28
lines changed

3 files changed

+281
-28
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 151 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ cdef void _managed_tensor_versioned_deleter(DLManagedTensorVersioned *dlmv_tenso
168168
stdlib.free(dlmv_tensor)
169169

170170

171-
cdef object _get_default_context(c_dpctl.SyclDevice dev) except *:
171+
cdef object _get_default_context(c_dpctl.SyclDevice dev):
172172
try:
173173
default_context = dev.sycl_platform.default_context
174174
except RuntimeError:
@@ -178,7 +178,7 @@ cdef object _get_default_context(c_dpctl.SyclDevice dev) except *:
178178
return default_context
179179

180180

181-
cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except *:
181+
cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except -1:
182182
cdef DPCTLSyclDeviceRef pDRef = NULL
183183
cdef DPCTLSyclDeviceRef tDRef = NULL
184184
cdef c_dpctl.SyclDevice p_dev
@@ -201,7 +201,7 @@ cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except *:
201201

202202
cdef int get_array_dlpack_device_id(
203203
usm_ndarray usm_ary
204-
) except *:
204+
) except -1:
205205
"""Finds ordinal number of the parent of device where array
206206
was allocated.
207207
"""
@@ -935,6 +935,32 @@ cpdef object from_dlpack_capsule(object py_caps):
935935
"The DLPack tensor resides on unsupported device."
936936
)
937937

938+
cdef usm_ndarray _to_usm_ary_from_host_blob(object host_blob, dev : Device):
939+
q = dev.sycl_queue
940+
np_ary = np.asarray(host_blob)
941+
dt = np_ary.dtype
942+
if dt.char in "dD" and q.sycl_device.has_aspect_fp64 is False:
943+
Xusm_dtype = (
944+
"float32" if dt.char == "d" else "complex64"
945+
)
946+
else:
947+
Xusm_dtype = dt
948+
usm_mem = dpmem.MemoryUSMDevice(np_ary.nbytes, queue=q)
949+
usm_ary = usm_ndarray(np_ary.shape, dtype=Xusm_dtype, buffer=usm_mem)
950+
usm_mem.copy_from_host(np.reshape(np_ary.view(dtype="u1"), -1))
951+
return usm_ary
952+
953+
954+
# only cdef to make it private
955+
cdef object _create_device(object device, object dl_device):
956+
if isinstance(device, Device):
957+
return device
958+
elif isinstance(device, dpctl.SyclDevice):
959+
return Device.create_device(device)
960+
else:
961+
root_device = dpctl.SyclDevice(str(<int>dl_device[1]))
962+
return Device.create_device(root_device)
963+
938964

939965
def from_dlpack(x, /, *, device=None, copy=None):
940966
""" from_dlpack(x, /, *, device=None, copy=None)
@@ -943,7 +969,7 @@ def from_dlpack(x, /, *, device=None, copy=None):
943969
object ``x`` that implements ``__dlpack__`` protocol.
944970
945971
Args:
946-
x (Python object):
972+
x (object):
947973
A Python object representing an array that supports
948974
``__dlpack__`` protocol.
949975
device (Optional[str,
@@ -959,7 +985,8 @@ def from_dlpack(x, /, *, device=None, copy=None):
959985
returned by :attr:`dpctl.tensor.usm_ndarray.device`, or a
960986
2-tuple matching the format of the output of the ``__dlpack_device__``
961987
method, an integer enumerator representing the device type followed by
962-
an integer representing the index of the device.
988+
an integer representing the index of the device. The only supported
989+
:enum:`dpctl.tensor.DLDeviceType` types are "kDLCPU" and "kDLOneAPI".
963990
Default: ``None``.
964991
copy (bool, optional)
965992
Boolean indicating whether or not to copy the input.
@@ -1008,33 +1035,130 @@ def from_dlpack(x, /, *, device=None, copy=None):
10081035
10091036
C = Container(dpt.linspace(0, 100, num=20, dtype="int16"))
10101037
X = dpt.from_dlpack(C)
1038+
Y = dpt.from_dlpack(C, device=(dpt.DLDeviceType.kDLCPU, 0))
10111039
10121040
"""
1013-
if not hasattr(x, "__dlpack__"):
1014-
raise TypeError(
1015-
f"The argument of type {type(x)} does not implement "
1016-
"`__dlpack__` method."
1017-
)
1018-
dlpack_attr = getattr(x, "__dlpack__")
1019-
if not callable(dlpack_attr):
1041+
dlpack_attr = getattr(x, "__dlpack__", None)
1042+
dlpack_dev_attr = getattr(x, "__dlpack_device__", None)
1043+
if not callable(dlpack_attr) or not callable(dlpack_dev_attr):
10201044
raise TypeError(
10211045
f"The argument of type {type(x)} does not implement "
1022-
"`__dlpack__` method."
1046+
"`__dlpack__` and `__dlpack_device__` methods."
10231047
)
1024-
try:
1025-
# device is converted to a dlpack_device if necessary
1026-
dl_device = None
1027-
if device:
1028-
if isinstance(device, tuple):
1029-
dl_device = device
1048+
# device is converted to a dlpack_device if necessary
1049+
dl_device = None
1050+
if device:
1051+
if isinstance(device, tuple):
1052+
dl_device = device
1053+
if len(dl_device) != 2:
1054+
raise ValueError(
1055+
"Argument `device` specified as a tuple must have length 2"
1056+
)
1057+
else:
1058+
if not isinstance(device, dpctl.SyclDevice):
1059+
device = Device.create_device(device)
1060+
d = device.sycl_device
10301061
else:
1031-
if not isinstance(device, dpctl.SyclDevice):
1032-
d = Device.create_device(device).sycl_device
1033-
dl_device = (device_OneAPI, get_parent_device_ordinal_id(<c_dpctl.SyclDevice>d))
1034-
else:
1035-
dl_device = (device_OneAPI, get_parent_device_ordinal_id(<c_dpctl.SyclDevice>device))
1036-
dlpack_capsule = dlpack_attr(max_version=get_build_dlpack_version(), dl_device=dl_device, copy=copy)
1037-
return from_dlpack_capsule(dlpack_capsule)
1062+
d = device
1063+
dl_device = (device_OneAPI, get_parent_device_ordinal_id(<c_dpctl.SyclDevice>d))
1064+
if dl_device is not None:
1065+
if (dl_device[0] not in [device_OneAPI, device_CPU]):
1066+
raise ValueError(
1067+
f"Argument `device`={device} is not supported."
1068+
)
1069+
got_type_error = False
1070+
got_buffer_error = False
1071+
got_other_error = False
1072+
saved_exception = None
1073+
# First DLPack version supporting dl_device, and copy
1074+
requested_ver = (1, 0)
1075+
cpu_dev = (device_CPU, 0)
1076+
try:
1077+
# setting max_version to minimal version that supports dl_device/copy keywords
1078+
dlpack_capsule = dlpack_attr(
1079+
max_version=requested_ver,
1080+
dl_device=dl_device,
1081+
copy=copy
1082+
)
10381083
except TypeError:
1039-
dlpack_capsule = dlpack_attr()
1084+
# exporter does not support max_version keyword
1085+
got_type_error = True
1086+
except (BufferError, NotImplementedError):
1087+
# Either dl_device, or copy can be satisfied
1088+
got_buffer_error = True
1089+
except Exception as e:
1090+
got_other_error = True
1091+
saved_exception = e
1092+
else:
1093+
# execution did not raise exceptions
10401094
return from_dlpack_capsule(dlpack_capsule)
1095+
finally:
1096+
if got_type_error:
1097+
# max_version/dl_device, copy keywords are not supported by __dlpack__
1098+
x_dldev = dlpack_dev_attr()
1099+
if (dl_device is None) or (dl_device == x_dldev):
1100+
dlpack_capsule = dlpack_attr()
1101+
return from_dlpack_capsule(dlpack_capsule)
1102+
# must copy via host
1103+
if copy is False:
1104+
raise BufferError(
1105+
"Importing data via DLPack requires copying, but copy=False was provided"
1106+
)
1107+
# when max_version/dl_device/copy are not supported
1108+
# we can only support importing to OneAPI devices
1109+
# from host, or from another oneAPI device
1110+
is_supported_x_dldev = (
1111+
x_dldev == cpu_dev or
1112+
(x_dldev[0] == device_OneAPI)
1113+
)
1114+
is_supported_dl_device = (
1115+
dl_device == cpu_dev or
1116+
dl_device[0] == device_OneAPI
1117+
)
1118+
if is_supported_x_dldev and is_supported_dl_device:
1119+
dlpack_capsule = dlpack_attr()
1120+
blob = from_dlpack_capsule(dlpack_capsule)
1121+
else:
1122+
raise BufferError(f"Can not import to requested device {dl_device}")
1123+
dev = _create_device(device, dl_device)
1124+
if x_dldev == cpu_dev and dl_device == cpu_dev:
1125+
# both source and destination are CPU
1126+
return blob
1127+
elif x_dldev == cpu_dev:
1128+
# source is CPU, destination is oneAPI
1129+
return _to_usm_ary_from_host_blob(blob, dev)
1130+
elif dl_device == cpu_dev:
1131+
# source is oneAPI, destination is CPU
1132+
cpu_caps = blob.__dlpack__(
1133+
max_version=get_build_dlpack_version(),
1134+
dl_device=cpu_dev
1135+
)
1136+
return from_dlpack_capsule(cpu_caps)
1137+
else:
1138+
import dpctl.tensor as dpt
1139+
return dpt.asarray(blob, device=dev)
1140+
elif got_buffer_error:
1141+
# we are here, because dlpack_attr could not deal with requested dl_device,
1142+
# or copying was required
1143+
if copy is False:
1144+
raise BufferError(
1145+
"Importing data via DLPack requires copying, but copy=False was provided"
1146+
)
1147+
# must copy via host
1148+
if dl_device[0] != device_OneAPI:
1149+
raise BufferError(f"Can not import to requested device {dl_device}")
1150+
x_dldev = dlpack_dev_attr()
1151+
if x_dldev == cpu_dev:
1152+
dlpack_capsule = dlpack_attr()
1153+
host_blob = from_dlpack_capsule(dlpack_capsule)
1154+
else:
1155+
dlpack_capsule = dlpack_attr(
1156+
max_version=requested_ver,
1157+
dl_device=cpu_dev,
1158+
copy=copy
1159+
)
1160+
host_blob = from_dlpack_capsule(dlpack_capsule)
1161+
dev = _create_device(device, dl_device)
1162+
return _to_usm_ary_from_host_blob(host_blob, dev)
1163+
elif got_other_error:
1164+
raise saved_exception

dpctl/tensor/_usmarray.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1242,7 +1242,7 @@ cdef class usm_ndarray:
12421242
_arr.flags["W"] = self.flags["W"]
12431243
return c_dlpack.numpy_to_dlpack_versioned_capsule(_arr, True)
12441244
else:
1245-
raise NotImplementedError(
1245+
raise BufferError(
12461246
f"targeting `dl_device` {dl_device} with `__dlpack__` is not "
12471247
"yet implemented"
12481248
)

dpctl/tests/test_usm_ndarray_dlpack.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,3 +696,132 @@ def test_dlpack_size_0_on_kdlcpu():
696696
cap = x_np.__dlpack__()
697697
y = _dlp.from_dlpack_capsule(cap)
698698
assert y.ctypes.data == x_np.ctypes.data
699+
700+
701+
def test_copy_via_host():
702+
get_queue_or_skip()
703+
x = dpt.ones(1, dtype="i4")
704+
x_np = np.ones(1, dtype="i4")
705+
x_dl_dev = x.__dlpack_device__()
706+
y = dpt.from_dlpack(x_np, device=x_dl_dev)
707+
assert isinstance(y, dpt.usm_ndarray)
708+
assert y.sycl_device == x.sycl_device
709+
assert y.usm_type == "device"
710+
711+
with pytest.raises(ValueError):
712+
# uncorrect length of tuple
713+
dpt.from_dlpack(x_np, device=(1, 0, 0))
714+
with pytest.raises(ValueError):
715+
# only kDLCPU and kDLOneAPI are supported
716+
dpt.from_dlpack(x, device=(2, 0))
717+
718+
num_devs = dpctl.get_num_devices()
719+
if num_devs > 1:
720+
j = [i for i in range(num_devs) if i != x_dl_dev[1]][0]
721+
z = dpt.from_dlpack(x, device=(x_dl_dev[0], j))
722+
assert isinstance(z, dpt.usm_ndarray)
723+
assert z.usm_type == "device"
724+
725+
726+
def test_copy_via_host_gh_1789():
727+
"Test based on review example from gh-1789"
728+
get_queue_or_skip()
729+
x_np = np.ones((10, 10), dtype="i4")
730+
# strides are no longer multiple of itemsize
731+
x_np.strides = (x_np.strides[0] - 1, x_np.strides[1])
732+
with pytest.raises(BufferError):
733+
dpt.from_dlpack(x_np)
734+
with pytest.raises(BufferError):
735+
dpt.from_dlpack(x_np, device=(14, 0))
736+
737+
738+
class LegacyContainer:
739+
"Helper class implementing legacy `__dlpack__` protocol"
740+
741+
def __init__(self, array):
742+
self._array = array
743+
744+
def __dlpack__(self, stream=None):
745+
return self._array.__dlpack__(stream=stream)
746+
747+
def __dlpack_device__(self):
748+
return self._array.__dlpack_device__()
749+
750+
751+
class Container:
752+
"Helper class implementing legacy `__dlpack__` protocol"
753+
754+
def __init__(self, array):
755+
self._array = array
756+
757+
def __dlpack__(
758+
self, max_version=None, dl_device=None, copy=None, stream=None
759+
):
760+
return self._array.__dlpack__(
761+
max_version=max_version,
762+
dl_device=dl_device,
763+
copy=copy,
764+
stream=stream,
765+
)
766+
767+
def __dlpack_device__(self):
768+
return self._array.__dlpack_device__()
769+
770+
771+
def test_generic_container_legacy():
772+
get_queue_or_skip()
773+
C = LegacyContainer(dpt.linspace(0, 100, num=20, dtype="int16"))
774+
775+
X = dpt.from_dlpack(C)
776+
assert isinstance(X, dpt.usm_ndarray)
777+
assert X._pointer == C._array._pointer
778+
assert X.sycl_device == C._array.sycl_device
779+
assert X.dtype == C._array.dtype
780+
781+
Y = dpt.from_dlpack(C, device=(dpt.DLDeviceType.kDLCPU, 0))
782+
assert isinstance(Y, np.ndarray)
783+
assert Y.dtype == X.dtype
784+
785+
Z = dpt.from_dlpack(C, device=X.device)
786+
assert isinstance(Z, dpt.usm_ndarray)
787+
assert Z._pointer == X._pointer
788+
assert Z.device == X.device
789+
790+
791+
def test_generic_container_legacy_np():
792+
get_queue_or_skip()
793+
C = LegacyContainer(np.linspace(0, 100, num=20, dtype="int16"))
794+
795+
X = dpt.from_dlpack(C)
796+
assert isinstance(X, np.ndarray)
797+
assert X.ctypes.data == C._array.ctypes.data
798+
assert X.dtype == C._array.dtype
799+
800+
Y = dpt.from_dlpack(C, device=(dpt.DLDeviceType.kDLCPU, 0))
801+
assert isinstance(Y, np.ndarray)
802+
assert Y.dtype == X.dtype
803+
804+
dev = dpt.Device.create_device()
805+
Z = dpt.from_dlpack(C, device=dev)
806+
assert isinstance(Z, dpt.usm_ndarray)
807+
assert Z.device == dev
808+
809+
810+
def test_generic_container():
811+
get_queue_or_skip()
812+
C = Container(dpt.linspace(0, 100, num=20, dtype="int16"))
813+
814+
X = dpt.from_dlpack(C)
815+
assert isinstance(X, dpt.usm_ndarray)
816+
assert X._pointer == C._array._pointer
817+
assert X.sycl_device == C._array.sycl_device
818+
assert X.dtype == C._array.dtype
819+
820+
Y = dpt.from_dlpack(C, device=(dpt.DLDeviceType.kDLCPU, 0))
821+
assert isinstance(Y, np.ndarray)
822+
assert Y.dtype == X.dtype
823+
824+
Z = dpt.from_dlpack(C, device=X.device)
825+
assert isinstance(Z, dpt.usm_ndarray)
826+
assert Z._pointer == X._pointer
827+
assert Z.device == X.device

0 commit comments

Comments
 (0)