Skip to content

Commit cb80f55

Browse files
committed
Adds validation for dl_device argument in __dlpack__
1 parent c3655ed commit cb80f55

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,6 +1167,12 @@ cdef class usm_ndarray:
11671167
if max_version[0] >= dpctl_dlpack_version[0]:
11681168
# DLManagedTensorVersioned path
11691169
if dl_device is not None:
1170+
if not isinstance(dl_device, tuple) or len(dl_device) != 2:
1171+
raise TypeError(
1172+
"`__dlpack__` expects `dl_device` to be a "
1173+
"2-tuple of `(device_type, device_id)`, instead "
1174+
f"got {type(dl_device)}"
1175+
)
11701176
if dl_device != self.__dlpack_device__():
11711177
if copy == False:
11721178
raise BufferError(

0 commit comments

Comments
 (0)