@@ -65,7 +65,15 @@ import ctypes
65
65
from .enum_types import backend_type
66
66
67
67
from cpython cimport pycapsule
68
- from cpython.buffer cimport PyObject_CheckBuffer
68
+ from cpython.buffer cimport (
69
+ Py_buffer,
70
+ PyBUF_ANY_CONTIGUOUS,
71
+ PyBUF_SIMPLE,
72
+ PyBUF_WRITABLE,
73
+ PyBuffer_Release,
74
+ PyObject_CheckBuffer,
75
+ PyObject_GetBuffer,
76
+ )
69
77
from cpython.ref cimport Py_DECREF, Py_INCREF, PyObject
70
78
from libc.stdlib cimport free, malloc
71
79
@@ -338,14 +346,20 @@ cdef DPCTLSyclEventRef _memcpy_impl(
338
346
cdef void * c_dst_ptr = NULL
339
347
cdef void * c_src_ptr = NULL
340
348
cdef DPCTLSyclEventRef ERef = NULL
341
- cdef const unsigned char [::1 ] src_host_buf = None
342
- cdef unsigned char [::1 ] dst_host_buf = None
349
+ cdef Py_buffer src_buf_view
350
+ cdef Py_buffer dst_buf_view
351
+ cdef bint src_is_buf = False
352
+ cdef bint dst_is_buf = False
353
+ cdef int ret_code = 0
343
354
344
355
if isinstance (src, _Memory):
345
356
c_src_ptr = < void * > (< _Memory> src).get_data_ptr()
346
357
elif _is_buffer(src):
347
- src_host_buf = src
348
- c_src_ptr = < void * > & src_host_buf[0 ]
358
+ ret_code = PyObject_GetBuffer(src, & src_buf_view, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS)
359
+ if ret_code != 0 : # pragma: no cover
360
+ raise RuntimeError (" Could not access buffer" )
361
+ c_src_ptr = src_buf_view.buf
362
+ src_is_buf = True
349
363
else :
350
364
raise TypeError (
351
365
" Parameter `src` should have either type "
@@ -356,8 +370,13 @@ cdef DPCTLSyclEventRef _memcpy_impl(
356
370
if isinstance (dst, _Memory):
357
371
c_dst_ptr = < void * > (< _Memory> dst).get_data_ptr()
358
372
elif _is_buffer(dst):
359
- dst_host_buf = dst
360
- c_dst_ptr = < void * > & dst_host_buf[0 ]
373
+ ret_code = PyObject_GetBuffer(dst, & dst_buf_view, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS | PyBUF_WRITABLE)
374
+ if ret_code != 0 : # pragma: no cover
375
+ if src_is_buf:
376
+ PyBuffer_Release(& src_buf_view)
377
+ raise RuntimeError (" Could not access buffer" )
378
+ c_dst_ptr = dst_buf_view.buf
379
+ dst_is_buf = True
361
380
else :
362
381
raise TypeError (
363
382
" Parameter `dst` should have either type "
@@ -376,6 +395,12 @@ cdef DPCTLSyclEventRef _memcpy_impl(
376
395
dep_events,
377
396
dep_events_count
378
397
)
398
+
399
+ if src_is_buf:
400
+ PyBuffer_Release(& src_buf_view)
401
+ if dst_is_buf:
402
+ PyBuffer_Release(& dst_buf_view)
403
+
379
404
return ERef
380
405
381
406
0 commit comments