diff --git a/dpctl/_sycl_queue.pyx b/dpctl/_sycl_queue.pyx index c20d8724f5..86ef08f584 100644 --- a/dpctl/_sycl_queue.pyx +++ b/dpctl/_sycl_queue.pyx @@ -65,7 +65,15 @@ import ctypes from .enum_types import backend_type from cpython cimport pycapsule -from cpython.buffer cimport PyObject_CheckBuffer +from cpython.buffer cimport ( + Py_buffer, + PyBUF_ANY_CONTIGUOUS, + PyBUF_SIMPLE, + PyBUF_WRITABLE, + PyBuffer_Release, + PyObject_CheckBuffer, + PyObject_GetBuffer, +) from cpython.ref cimport Py_DECREF, Py_INCREF, PyObject from libc.stdlib cimport free, malloc @@ -338,14 +346,20 @@ cdef DPCTLSyclEventRef _memcpy_impl( cdef void *c_dst_ptr = NULL cdef void *c_src_ptr = NULL cdef DPCTLSyclEventRef ERef = NULL - cdef const unsigned char[::1] src_host_buf = None - cdef unsigned char[::1] dst_host_buf = None + cdef Py_buffer src_buf_view + cdef Py_buffer dst_buf_view + cdef bint src_is_buf = False + cdef bint dst_is_buf = False + cdef int ret_code = 0 if isinstance(src, _Memory): c_src_ptr = (<_Memory>src).get_data_ptr() elif _is_buffer(src): - src_host_buf = src - c_src_ptr = &src_host_buf[0] + ret_code = PyObject_GetBuffer(src, &src_buf_view, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS) + if ret_code != 0: # pragma: no cover + raise RuntimeError("Could not access buffer") + c_src_ptr = src_buf_view.buf + src_is_buf = True else: raise TypeError( "Parameter `src` should have either type " @@ -356,8 +370,13 @@ cdef DPCTLSyclEventRef _memcpy_impl( if isinstance(dst, _Memory): c_dst_ptr = (<_Memory>dst).get_data_ptr() elif _is_buffer(dst): - dst_host_buf = dst - c_dst_ptr = &dst_host_buf[0] + ret_code = PyObject_GetBuffer(dst, &dst_buf_view, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS | PyBUF_WRITABLE) + if ret_code != 0: # pragma: no cover + if src_is_buf: + PyBuffer_Release(&src_buf_view) + raise RuntimeError("Could not access buffer") + c_dst_ptr = dst_buf_view.buf + dst_is_buf = True else: raise TypeError( "Parameter `dst` should have either type " @@ -376,6 +395,12 @@ cdef DPCTLSyclEventRef _memcpy_impl( dep_events, dep_events_count ) + + if src_is_buf: + PyBuffer_Release(&src_buf_view) + if dst_is_buf: + PyBuffer_Release(&dst_buf_view) + return ERef diff --git a/dpctl/tests/test_sycl_queue_memcpy.py b/dpctl/tests/test_sycl_queue_memcpy.py index 1756cca40a..d134323e74 100644 --- a/dpctl/tests/test_sycl_queue_memcpy.py +++ b/dpctl/tests/test_sycl_queue_memcpy.py @@ -17,6 +17,7 @@ """Defines unit test cases for the SyclQueue.memcpy. """ +import numpy as np import pytest import dpctl @@ -97,6 +98,44 @@ def test_memcpy_copy_host_to_host(): assert dst_buf == src_buf +def test_2D_memcpy_copy_host_to_usm(): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Default constructor for SyclQueue failed") + usm_obj = _create_memory(q) + + n = 12 + canary = bytearray([i for i in range(n)]) + host_obj = np.frombuffer(canary, dtype=np.uint8).reshape(3, 4) + + q.memcpy(usm_obj, host_obj, len(canary)) + + mv2 = memoryview(usm_obj) + + assert mv2[: len(canary)] == canary + + +def test_2D_memcpy_copy_usm_to_host(): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Default constructor for SyclQueue failed") + usm_obj = _create_memory(q) + mv2 = memoryview(usm_obj) + + n = 12 + shape = (3, 4) + for id in range(n): + mv2[id] = id + + host_obj = np.ones(shape, dtype=np.uint8) + + q.memcpy(host_obj, usm_obj, n) + + assert np.array_equal(host_obj, np.arange(n, dtype=np.uint8).reshape(shape)) + + def test_memcpy_async(): try: q = dpctl.SyclQueue()