Skip to content

Commit 3c35c3b

Browse files
committed
Chane per PR review by @oleksandr-pavlyk and pass NULL strides to dl_tensor if NumPy array is C-contiguous
1 parent cb80f55 commit 3c35c3b

File tree

1 file changed

+24
-8
lines changed

1 file changed

+24
-8
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -515,27 +515,43 @@ cpdef numpy_to_dlpack_versioned_capsule(ndarray npy_ary, bint copied):
515515
cdef int i = 0
516516
cdef int device_id = -1
517517
cdef Py_ssize_t byte_offset = 0
518+
cdef int itemsize = npy_ary.itemsize
518519

519520
dlmv_tensor = <DLManagedTensorVersioned *> stdlib.malloc(
520521
sizeof(DLManagedTensorVersioned))
521522
if dlmv_tensor is NULL:
522523
raise MemoryError(
523-
"to_dlpack_versioned_capsule: Could not allocate memory "
524+
"numpy_to_dlpack_versioned_capsule: Could not allocate memory "
524525
"for DLManagedTensorVersioned"
525526
)
526-
shape_strides_ptr = <int64_t *>stdlib.malloc((sizeof(int64_t) * 2) * nd)
527+
528+
is_c_contiguous = npy_ary.flags["C"]
529+
shape = npy_ary.ctypes.shape_as(ctypes.c_int64)
530+
strides = npy_ary.ctypes.strides_as(ctypes.c_int64)
531+
if not is_c_contiguous:
532+
if npy_ary.size != 1:
533+
for i in range(nd):
534+
if shape[i] != 1 and strides[i] % itemsize != 0:
535+
stdlib.free(dlmv_tensor)
536+
raise BufferError(
537+
"numpy_to_dlpack_versioned_capsule: DLPack cannot encode "
538+
"an array if strides are not a multiple of itemsize"
539+
)
540+
shape_strides_ptr = <int64_t *>stdlib.malloc((sizeof(int64_t) * 2) * nd)
541+
else:
542+
# no need to pass strides in this case
543+
shape_strides_ptr = <int64_t *>stdlib.malloc(sizeof(int64_t) * nd)
527544
if shape_strides_ptr is NULL:
528545
stdlib.free(dlmv_tensor)
529546
raise MemoryError(
530-
"to_dlpack_versioned_capsule: Could not allocate memory "
547+
"numpy_to_dlpack_versioned_capsule: Could not allocate memory "
531548
"for shape/strides"
532549
)
533-
# this can be a separate function for handling shapes and strides
534-
shape = npy_ary.ctypes.shape_as(ctypes.c_int64)
535-
strides = npy_ary.ctypes.strides_as(ctypes.c_int64)
536550
for i in range(nd):
537551
shape_strides_ptr[i] = shape[i]
538-
shape_strides_ptr[nd + i] = strides[i] // npy_ary.itemsize
552+
if not is_c_contiguous:
553+
shape_strides_ptr[nd + i] = strides[i] // itemsize
554+
539555
writable_flag = npy_ary.flags["W"]
540556

541557
ary_dt = npy_ary.dtype
@@ -546,7 +562,7 @@ cpdef numpy_to_dlpack_versioned_capsule(ndarray npy_ary, bint copied):
546562
dl_tensor.ndim = nd
547563
dl_tensor.byte_offset = <uint64_t>byte_offset
548564
dl_tensor.shape = &shape_strides_ptr[0]
549-
dl_tensor.strides = &shape_strides_ptr[nd]
565+
dl_tensor.strides = &shape_strides_ptr[nd] if not is_c_contiguous else NULL
550566
dl_tensor.device.device_type = kDLCPU
551567
dl_tensor.device.device_id = 0
552568
dl_tensor.dtype.lanes = <uint16_t>1

0 commit comments

Comments
 (0)