diff --git a/dpctl/__init__.py b/dpctl/__init__.py
index 1d9b7209e4..e4dd710ade 100644
--- a/dpctl/__init__.py
+++ b/dpctl/__init__.py
@@ -48,6 +48,7 @@
 from ._sycl_event import SyclEvent
 from ._sycl_platform import SyclPlatform, get_platforms, lsplatform
 from ._sycl_queue import (
+    LocalAccessor,
     SyclKernelInvalidRangeError,
     SyclKernelSubmitError,
     SyclQueue,
@@ -102,6 +103,7 @@
     "SyclKernelSubmitError",
     "SyclQueueCreationError",
     "WorkGroupMemory",
+    "LocalAccessor",
 ]
 __all__ += [
     "get_device_cached_queue",
diff --git a/dpctl/_backend.pxd b/dpctl/_backend.pxd
index 1f27a1f540..9b41499160 100644
--- a/dpctl/_backend.pxd
+++ b/dpctl/_backend.pxd
@@ -362,6 +362,12 @@ cdef extern from "syclinterface/dpctl_sycl_kernel_bundle_interface.h":
 
 
 cdef extern from "syclinterface/dpctl_sycl_queue_interface.h":
+    ctypedef struct _md_local_accessor 'MDLocalAccessor':
+        size_t ndim
+        _arg_data_type dpctl_type_id
+        size_t dim0
+        size_t dim1
+        size_t dim2
     cdef bool DPCTLQueue_AreEq(const DPCTLSyclQueueRef QRef1,
                                const DPCTLSyclQueueRef QRef2)
     cdef DPCTLSyclQueueRef DPCTLQueue_Create(
diff --git a/dpctl/_sycl_queue.pyx b/dpctl/_sycl_queue.pyx
index 94527506ef..ad44e8faa2 100644
--- a/dpctl/_sycl_queue.pyx
+++ b/dpctl/_sycl_queue.pyx
@@ -59,6 +59,7 @@ from ._backend cimport (  # noqa: E211
     DPCTLWorkGroupMemory_Delete,
     _arg_data_type,
     _backend_type,
+    _md_local_accessor,
     _queue_property_type,
 )
 from .memory._memory cimport _Memory
@@ -125,6 +126,95 @@ cdef class kernel_arg_type_attribute:
         return self.attr_value
 
 
+cdef class LocalAccessor:
+    """
+    LocalAccessor(dtype, shape)
+
+    Python class for specifying the dimensionality and type of a
+    ``sycl::local_accessor``, to be used as a kernel argument type.
+
+    Args:
+        dtype (str):
+            the data type of the local memory.
+            The permitted values are
+
+                `'i1'`, `'i2'`, `'i4'`, `'i8'`:
+                    signed integral types int8_t, int16_t, int32_t, int64_t
+                `'u1'`, `'u2'`, `'u4'`, `'u8'`:
+                    unsigned integral types uint8_t, uint16_t, uint32_t,
+                    uint64_t
+                `'f4'`, `'f8'`:
+                    single- and double-precision floating-point types float and
+                    double
+        shape (tuple, list):
+            Size of LocalAccessor dimensions. Dimension of the LocalAccessor is
+            determined by the length of the tuple. Must be of length 1, 2, or 3,
+            and contain only non-negative integers.
+
+    Raises:
+        TypeError:
+            If the given shape is not a tuple or list.
+        ValueError:
+            If the given shape sequence is not between one and three elements long.
+        TypeError:
+            If the shape is not a sequence of integers.
+        ValueError:
+            If the shape contains a negative integer.
+        ValueError:
+            If the dtype string is unrecognized.
+    """
+    cdef _md_local_accessor lacc
+
+    def __cinit__(self, str dtype, shape):
+        if not isinstance(shape, (list, tuple)):
+            raise TypeError(f"`shape` must be a list or tuple, got {type(shape)}")
+        ndim = len(shape)
+        if ndim < 1 or ndim > 3:
+            raise ValueError("LocalAccessor must have dimension between one and three")
+        for s in shape:
+            if not isinstance(s, numbers.Integral):
+                raise TypeError("LocalAccessor shape must be a sequence of integers")
+            if s < 0:
+                raise ValueError("LocalAccessor dimensions must be non-negative")
+        self.lacc.ndim = ndim
+        self.lacc.dim0 = shape[0]
+        self.lacc.dim1 = shape[1] if ndim > 1 else 1
+        self.lacc.dim2 = shape[2] if ndim > 2 else 1
+
+        if dtype == 'i1':
+            self.lacc.dpctl_type_id = _arg_data_type._INT8_T
+        elif dtype == 'u1':
+            self.lacc.dpctl_type_id = _arg_data_type._UINT8_T
+        elif dtype == 'i2':
+            self.lacc.dpctl_type_id = _arg_data_type._INT16_T
+        elif dtype == 'u2':
+            self.lacc.dpctl_type_id = _arg_data_type._UINT16_T
+        elif dtype == 'i4':
+            self.lacc.dpctl_type_id = _arg_data_type._INT32_T
+        elif dtype == 'u4':
+            self.lacc.dpctl_type_id = _arg_data_type._UINT32_T
+        elif dtype == 'i8':
+            self.lacc.dpctl_type_id = _arg_data_type._INT64_T
+        elif dtype == 'u8':
+            self.lacc.dpctl_type_id = _arg_data_type._UINT64_T
+        elif dtype == 'f4':
+            self.lacc.dpctl_type_id = _arg_data_type._FLOAT
+        elif dtype == 'f8':
+            self.lacc.dpctl_type_id = _arg_data_type._DOUBLE
+        else:
+            raise ValueError(f"Unrecognized type value: '{dtype}'")
+
+    def __repr__(self):
+        return f"LocalAccessor({self.lacc.ndim})"
+
+    cdef size_t addressof(self):
+        """
+        Returns the address of the _md_local_accessor for this LocalAccessor
+        cast to ``size_t``.
+        """
+        return <size_t>&self.lacc
+
+
 cdef class _kernel_arg_type:
     """
     An enumeration of supported kernel argument types in
@@ -865,6 +955,9 @@ cdef class SyclQueue(_SyclQueue):
             elif isinstance(arg, WorkGroupMemory):
                 kargs[idx] = <void*>(<size_t>arg._ref)
                 kargty[idx] = _arg_data_type._WORK_GROUP_MEMORY
+            elif isinstance(arg, LocalAccessor):
+                kargs[idx] = <void*>((<LocalAccessor>arg).addressof())
+                kargty[idx] = _arg_data_type._LOCAL_ACCESSOR
             else:
                 ret = -1
         return ret
diff --git a/dpctl/tests/input_files/local_accessor_kernel_fp64.spv b/dpctl/tests/input_files/local_accessor_kernel_fp64.spv
new file mode 100644
index 0000000000..ffc220268a
Binary files /dev/null and b/dpctl/tests/input_files/local_accessor_kernel_fp64.spv differ
diff --git a/dpctl/tests/input_files/local_accessor_kernel_inttys_fp32.spv b/dpctl/tests/input_files/local_accessor_kernel_inttys_fp32.spv
new file mode 100644
index 0000000000..3e2d145ad8
Binary files /dev/null and b/dpctl/tests/input_files/local_accessor_kernel_inttys_fp32.spv differ
diff --git a/dpctl/tests/test_sycl_kernel_submit.py b/dpctl/tests/test_sycl_kernel_submit.py
index 9575e228f2..e46c4f1760 100644
--- a/dpctl/tests/test_sycl_kernel_submit.py
+++ b/dpctl/tests/test_sycl_kernel_submit.py
@@ -18,6 +18,7 @@
 """
 
 import ctypes
+import os
 
 import numpy as np
 import pytest
@@ -279,3 +280,42 @@ def test_kernel_arg_type():
     _check_kernel_arg_type_instance(kernel_arg_type.dpctl_void_ptr)
     _check_kernel_arg_type_instance(kernel_arg_type.dpctl_local_accessor)
     _check_kernel_arg_type_instance(kernel_arg_type.dpctl_work_group_memory)
+
+
+def get_spirv_abspath(fn):
+    """Return the absolute path of ``fn`` under this directory's input_files."""
+    curr_dir = os.path.dirname(os.path.abspath(__file__))
+    return os.path.join(curr_dir, "input_files", fn)
+
+
+# the process for generating the .spv files in this test is documented in
+# libsyclinterface/tests/test_sycl_queue_submit_local_accessor_arg.cpp
+# in a comment starting on line 123
+def test_submit_local_accessor_arg():
+    try:
+        q = dpctl.SyclQueue("level_zero")
+    except dpctl.SyclQueueCreationError:
+        pytest.skip("Level Zero queue could not be created")
+    fn = get_spirv_abspath("local_accessor_kernel_inttys_fp32.spv")
+    with open(fn, "br") as f:
+        spirv_bytes = f.read()
+    prog = dpctl_prog.create_program_from_spirv(q, spirv_bytes)
+    krn = prog.get_sycl_kernel("_ZTS14SyclKernel_SLMIlE")
+    lws = 32
+    gws = lws * 10
+    x = dpt.ones(gws, dtype="i8")
+    x.sycl_queue.wait()
+    try:
+        e = q.submit(
+            krn,
+            [x.usm_data, dpctl.LocalAccessor("i8", (lws,))],
+            [gws],
+            [lws],
+        )
+        e.wait()
+    except dpctl._sycl_queue.SyclKernelSubmitError:
+        pytest.skip(f"Kernel submission failed for device {q.sycl_device}")
+    expected = dpt.arange(1, x.size + 1, dtype=x.dtype, device=x.device) * (
+        2 * lws
+    )
+    assert dpt.all(x == expected)