Skip to content

Commit 6f4cd55

Browse files
To ensure same validation across branches, compute host_blob by roundtripping it through dlpack
1 parent e18f6c1 commit 6f4cd55

File tree

1 file changed

+19
-17
lines changed

1 file changed

+19
-17
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,22 +1028,22 @@ def from_dlpack(x, /, *, device=None, copy=None):
10281028
f"The argument of type {type(x)} does not implement "
10291029
"`__dlpack__` and `__dlpack_device__` methods."
10301030
)
1031-
try:
1032-
# device is converted to a dlpack_device if necessary
1033-
dl_device = None
1034-
if device:
1035-
if isinstance(device, tuple):
1036-
dl_device = device
1037-
if len(dl_device) != 2:
1038-
raise ValueError(
1039-
"Argument `device` specified as a tuple must have length 2"
1040-
)
1031+
# device is converted to a dlpack_device if necessary
1032+
dl_device = None
1033+
if device:
1034+
if isinstance(device, tuple):
1035+
dl_device = device
1036+
if len(dl_device) != 2:
1037+
raise ValueError(
1038+
"Argument `device` specified as a tuple must have length 2"
1039+
)
1040+
else:
1041+
if not isinstance(device, dpctl.SyclDevice):
1042+
d = Device.create_device(device).sycl_device
10411043
else:
1042-
if not isinstance(device, dpctl.SyclDevice):
1043-
d = Device.create_device(device).sycl_device
1044-
else:
1045-
d = device
1046-
dl_device = (device_OneAPI, get_parent_device_ordinal_id(<c_dpctl.SyclDevice>d))
1044+
d = device
1045+
dl_device = (device_OneAPI, get_parent_device_ordinal_id(<c_dpctl.SyclDevice>d))
1046+
try:
10471047
dlpack_capsule = dlpack_attr(max_version=get_build_dlpack_version(), dl_device=dl_device, copy=copy)
10481048
return from_dlpack_capsule(dlpack_capsule)
10491049
except TypeError:
@@ -1058,7 +1058,8 @@ def from_dlpack(x, /, *, device=None, copy=None):
10581058
"Importing data via DLPack requires copying, but copy=False was provided"
10591059
)
10601060
if x_dldev == (device_CPU, 0) and dl_device[0] == device_OneAPI:
1061-
host_blob = x
1061+
dlpack_capsule = dlpack_attr()
1062+
host_blob = from_dlpack_capsule(dlpack_capsule)
10621063
else:
10631064
raise BufferError(f"Can not import to requested device {dl_device}")
10641065
return _to_usm_ary_from_host_blob(host_blob, dl_device[1])
@@ -1074,7 +1075,8 @@ def from_dlpack(x, /, *, device=None, copy=None):
10741075
raise BufferError(f"Can not import to requested device {dl_device}")
10751076
x_dldev = dlpack_dev_attr()
10761077
if x_dldev == (device_CPU, 0):
1077-
host_blob = x
1078+
dlpack_capsule = dlpack_attr()
1079+
host_blob = from_dlpack_capsule(dlpack_capsule)
10781080
else:
10791081
dlpack_capsule = dlpack_attr(max_version=(1, 0), dl_device=(device_CPU, 0), copy=copy)
10801082
host_blob = from_dlpack_capsule(dlpack_capsule)

0 commit comments

Comments
 (0)