Skip to content

Copy from/to multidimensional buffers #1985

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions dpctl/_sycl_queue.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,15 @@ import ctypes
from .enum_types import backend_type

from cpython cimport pycapsule
from cpython.buffer cimport PyObject_CheckBuffer
from cpython.buffer cimport (
Py_buffer,
PyBUF_ANY_CONTIGUOUS,
PyBUF_SIMPLE,
PyBUF_WRITABLE,
PyBuffer_Release,
PyObject_CheckBuffer,
PyObject_GetBuffer,
)
from cpython.ref cimport Py_DECREF, Py_INCREF, PyObject
from libc.stdlib cimport free, malloc

Expand Down Expand Up @@ -338,14 +346,20 @@ cdef DPCTLSyclEventRef _memcpy_impl(
cdef void *c_dst_ptr = NULL
cdef void *c_src_ptr = NULL
cdef DPCTLSyclEventRef ERef = NULL
cdef const unsigned char[::1] src_host_buf = None
cdef unsigned char[::1] dst_host_buf = None
cdef Py_buffer src_buf_view
cdef Py_buffer dst_buf_view
cdef bint src_is_buf = False
cdef bint dst_is_buf = False
cdef int ret_code = 0

if isinstance(src, _Memory):
c_src_ptr = <void*>(<_Memory>src).get_data_ptr()
elif _is_buffer(src):
src_host_buf = src
c_src_ptr = <void *>&src_host_buf[0]
ret_code = PyObject_GetBuffer(src, &src_buf_view, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS)
if ret_code != 0: # pragma: no cover
raise RuntimeError("Could not access buffer")
c_src_ptr = src_buf_view.buf
src_is_buf = True
else:
raise TypeError(
"Parameter `src` should have either type "
Expand All @@ -356,8 +370,13 @@ cdef DPCTLSyclEventRef _memcpy_impl(
if isinstance(dst, _Memory):
c_dst_ptr = <void*>(<_Memory>dst).get_data_ptr()
elif _is_buffer(dst):
dst_host_buf = dst
c_dst_ptr = <void *>&dst_host_buf[0]
ret_code = PyObject_GetBuffer(dst, &dst_buf_view, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS | PyBUF_WRITABLE)
if ret_code != 0: # pragma: no cover
if src_is_buf:
PyBuffer_Release(&src_buf_view)
raise RuntimeError("Could not access buffer")
c_dst_ptr = dst_buf_view.buf
dst_is_buf = True
else:
raise TypeError(
"Parameter `dst` should have either type "
Expand All @@ -376,6 +395,12 @@ cdef DPCTLSyclEventRef _memcpy_impl(
dep_events,
dep_events_count
)

if src_is_buf:
PyBuffer_Release(&src_buf_view)
if dst_is_buf:
PyBuffer_Release(&dst_buf_view)

return ERef


Expand Down
39 changes: 39 additions & 0 deletions dpctl/tests/test_sycl_queue_memcpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""Defines unit test cases for the SyclQueue.memcpy.
"""

import numpy as np
import pytest

import dpctl
Expand Down Expand Up @@ -97,6 +98,44 @@ def test_memcpy_copy_host_to_host():
assert dst_buf == src_buf


def test_2D_memcpy_copy_host_to_usm():
try:
q = dpctl.SyclQueue()
except dpctl.SyclQueueCreationError:
pytest.skip("Default constructor for SyclQueue failed")
usm_obj = _create_memory(q)

n = 12
canary = bytearray([i for i in range(n)])
host_obj = np.frombuffer(canary, dtype=np.uint8).reshape(3, 4)

q.memcpy(usm_obj, host_obj, len(canary))

mv2 = memoryview(usm_obj)

assert mv2[: len(canary)] == canary


def test_2D_memcpy_copy_usm_to_host():
try:
q = dpctl.SyclQueue()
except dpctl.SyclQueueCreationError:
pytest.skip("Default constructor for SyclQueue failed")
usm_obj = _create_memory(q)
mv2 = memoryview(usm_obj)

n = 12
shape = (3, 4)
for id in range(n):
mv2[id] = id

host_obj = np.ones(shape, dtype=np.uint8)

q.memcpy(host_obj, usm_obj, n)

assert np.array_equal(host_obj, np.arange(n, dtype=np.uint8).reshape(shape))


def test_memcpy_async():
try:
q = dpctl.SyclQueue()
Expand Down
Loading