@@ -1028,22 +1028,22 @@ def from_dlpack(x, /, *, device=None, copy=None):
1028
1028
f" The argument of type {type(x)} does not implement "
1029
1029
" `__dlpack__` and `__dlpack_device__` methods."
1030
1030
)
1031
- try :
1032
- # device is converted to a dlpack_device if necessary
1033
- dl_device = None
1034
- if device:
1035
- if isinstance (device, tuple ):
1036
- dl_device = device
1037
- if len (dl_device) != 2 :
1038
- raise ValueError (
1039
- " Argument `device` specified as a tuple must have length 2"
1040
- )
1031
+ # device is converted to a dlpack_device if necessary
1032
+ dl_device = None
1033
+ if device:
1034
+ if isinstance (device, tuple ):
1035
+ dl_device = device
1036
+ if len (dl_device) != 2 :
1037
+ raise ValueError (
1038
+ " Argument `device` specified as a tuple must have length 2"
1039
+ )
1040
+ else :
1041
+ if not isinstance (device, dpctl.SyclDevice):
1042
+ d = Device.create_device(device).sycl_device
1041
1043
else :
1042
- if not isinstance (device, dpctl.SyclDevice):
1043
- d = Device.create_device(device).sycl_device
1044
- else :
1045
- d = device
1046
- dl_device = (device_OneAPI, get_parent_device_ordinal_id(< c_dpctl.SyclDevice> d))
1044
+ d = device
1045
+ dl_device = (device_OneAPI, get_parent_device_ordinal_id(< c_dpctl.SyclDevice> d))
1046
+ try :
1047
1047
dlpack_capsule = dlpack_attr(max_version = get_build_dlpack_version(), dl_device = dl_device, copy = copy)
1048
1048
return from_dlpack_capsule(dlpack_capsule)
1049
1049
except TypeError :
@@ -1058,7 +1058,8 @@ def from_dlpack(x, /, *, device=None, copy=None):
1058
1058
" Importing data via DLPack requires copying, but copy=False was provided"
1059
1059
)
1060
1060
if x_dldev == (device_CPU, 0 ) and dl_device[0 ] == device_OneAPI:
1061
- host_blob = x
1061
+ dlpack_capsule = dlpack_attr()
1062
+ host_blob = from_dlpack_capsule(dlpack_capsule)
1062
1063
else :
1063
1064
raise BufferError(f" Can not import to requested device {dl_device}" )
1064
1065
return _to_usm_ary_from_host_blob(host_blob, dl_device[1 ])
@@ -1074,7 +1075,8 @@ def from_dlpack(x, /, *, device=None, copy=None):
1074
1075
raise BufferError(f" Can not import to requested device {dl_device}" )
1075
1076
x_dldev = dlpack_dev_attr()
1076
1077
if x_dldev == (device_CPU, 0 ):
1077
- host_blob = x
1078
+ dlpack_capsule = dlpack_attr()
1079
+ host_blob = from_dlpack_capsule(dlpack_capsule)
1078
1080
else :
1079
1081
dlpack_capsule = dlpack_attr(max_version = (1 , 0 ), dl_device = (device_CPU, 0 ), copy = copy)
1080
1082
host_blob = from_dlpack_capsule(dlpack_capsule)
0 commit comments