Skip to content

Commit 5529d66

Browse files
Merge pull request #1985 from sommerlukas/nd-memory-copy
Copy from/to multidimensional buffers
2 parents a332dd5 + a8bc75e commit 5529d66

File tree

2 files changed

+71
-7
lines changed

2 files changed

+71
-7
lines changed

dpctl/_sycl_queue.pyx

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,15 @@ import ctypes
6565
from .enum_types import backend_type
6666

6767
from cpython cimport pycapsule
68-
from cpython.buffer cimport PyObject_CheckBuffer
68+
from cpython.buffer cimport (
69+
Py_buffer,
70+
PyBUF_ANY_CONTIGUOUS,
71+
PyBUF_SIMPLE,
72+
PyBUF_WRITABLE,
73+
PyBuffer_Release,
74+
PyObject_CheckBuffer,
75+
PyObject_GetBuffer,
76+
)
6977
from cpython.ref cimport Py_DECREF, Py_INCREF, PyObject
7078
from libc.stdlib cimport free, malloc
7179

@@ -338,14 +346,20 @@ cdef DPCTLSyclEventRef _memcpy_impl(
338346
cdef void *c_dst_ptr = NULL
339347
cdef void *c_src_ptr = NULL
340348
cdef DPCTLSyclEventRef ERef = NULL
341-
cdef const unsigned char[::1] src_host_buf = None
342-
cdef unsigned char[::1] dst_host_buf = None
349+
cdef Py_buffer src_buf_view
350+
cdef Py_buffer dst_buf_view
351+
cdef bint src_is_buf = False
352+
cdef bint dst_is_buf = False
353+
cdef int ret_code = 0
343354

344355
if isinstance(src, _Memory):
345356
c_src_ptr = <void*>(<_Memory>src).get_data_ptr()
346357
elif _is_buffer(src):
347-
src_host_buf = src
348-
c_src_ptr = <void *>&src_host_buf[0]
358+
ret_code = PyObject_GetBuffer(src, &src_buf_view, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS)
359+
if ret_code != 0: # pragma: no cover
360+
raise RuntimeError("Could not access buffer")
361+
c_src_ptr = src_buf_view.buf
362+
src_is_buf = True
349363
else:
350364
raise TypeError(
351365
"Parameter `src` should have either type "
@@ -356,8 +370,13 @@ cdef DPCTLSyclEventRef _memcpy_impl(
356370
if isinstance(dst, _Memory):
357371
c_dst_ptr = <void*>(<_Memory>dst).get_data_ptr()
358372
elif _is_buffer(dst):
359-
dst_host_buf = dst
360-
c_dst_ptr = <void *>&dst_host_buf[0]
373+
ret_code = PyObject_GetBuffer(dst, &dst_buf_view, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS | PyBUF_WRITABLE)
374+
if ret_code != 0: # pragma: no cover
375+
if src_is_buf:
376+
PyBuffer_Release(&src_buf_view)
377+
raise RuntimeError("Could not access buffer")
378+
c_dst_ptr = dst_buf_view.buf
379+
dst_is_buf = True
361380
else:
362381
raise TypeError(
363382
"Parameter `dst` should have either type "
@@ -376,6 +395,12 @@ cdef DPCTLSyclEventRef _memcpy_impl(
376395
dep_events,
377396
dep_events_count
378397
)
398+
399+
if src_is_buf:
400+
PyBuffer_Release(&src_buf_view)
401+
if dst_is_buf:
402+
PyBuffer_Release(&dst_buf_view)
403+
379404
return ERef
380405

381406

dpctl/tests/test_sycl_queue_memcpy.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"""Defines unit test cases for the SyclQueue.memcpy.
1818
"""
1919

20+
import numpy as np
2021
import pytest
2122

2223
import dpctl
@@ -97,6 +98,44 @@ def test_memcpy_copy_host_to_host():
9798
assert dst_buf == src_buf
9899

99100

101+
def test_2D_memcpy_copy_host_to_usm():
102+
try:
103+
q = dpctl.SyclQueue()
104+
except dpctl.SyclQueueCreationError:
105+
pytest.skip("Default constructor for SyclQueue failed")
106+
usm_obj = _create_memory(q)
107+
108+
n = 12
109+
canary = bytearray([i for i in range(n)])
110+
host_obj = np.frombuffer(canary, dtype=np.uint8).reshape(3, 4)
111+
112+
q.memcpy(usm_obj, host_obj, len(canary))
113+
114+
mv2 = memoryview(usm_obj)
115+
116+
assert mv2[: len(canary)] == canary
117+
118+
119+
def test_2D_memcpy_copy_usm_to_host():
120+
try:
121+
q = dpctl.SyclQueue()
122+
except dpctl.SyclQueueCreationError:
123+
pytest.skip("Default constructor for SyclQueue failed")
124+
usm_obj = _create_memory(q)
125+
mv2 = memoryview(usm_obj)
126+
127+
n = 12
128+
shape = (3, 4)
129+
for id in range(n):
130+
mv2[id] = id
131+
132+
host_obj = np.ones(shape, dtype=np.uint8)
133+
134+
q.memcpy(host_obj, usm_obj, n)
135+
136+
assert np.array_equal(host_obj, np.arange(n, dtype=np.uint8).reshape(shape))
137+
138+
100139
def test_memcpy_async():
101140
try:
102141
q = dpctl.SyclQueue()

0 commit comments

Comments
 (0)