Skip to content

Commit 50774dc

Browse files
committed
Add support for work_group_memory extension
Extend kernel argument handling to add support for the work_group_memory extension, allowing users to dynamically allocate local memory for a kernel. Signed-off-by: Lukas Sommer <lukas.sommer@codeplay.com>
1 parent d84cb16 commit 50774dc

File tree

11 files changed

+151
-1
lines changed

11 files changed

+151
-1
lines changed

dpctl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,3 +207,4 @@ add_subdirectory(program)
207207
add_subdirectory(memory)
208208
add_subdirectory(tensor)
209209
add_subdirectory(utils)
210+
add_subdirectory(experimental)

dpctl/_backend.pxd

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ cdef extern from "syclinterface/dpctl_sycl_enum_types.h":
6969
_FLOAT 'DPCTL_FLOAT32_T',
7070
_DOUBLE 'DPCTL_FLOAT64_T',
7171
_VOID_PTR 'DPCTL_VOID_PTR',
72-
_LOCAL_ACCESSOR 'DPCTL_LOCAL_ACCESSOR'
72+
_LOCAL_ACCESSOR 'DPCTL_LOCAL_ACCESSOR',
73+
_WORK_GROUP_MEMORY 'DPCTL_WORK_GROUP_MEMORY'
7374

7475
ctypedef enum _queue_property_type 'DPCTLQueuePropertyType':
7576
_DEFAULT_PROPERTY 'DPCTL_DEFAULT_PROPERTY'

dpctl/_sycl_queue.pyx

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ from ._backend cimport ( # noqa: E211
5959
_queue_property_type,
6060
)
6161
from .memory._memory cimport _Memory
62+
from .experimental._work_group_memory cimport WorkGroupMemory
6263

6364
import ctypes
6465

@@ -242,6 +243,15 @@ cdef class _kernel_arg_type:
242243
_arg_data_type._LOCAL_ACCESSOR
243244
)
244245

246+
@property
247+
def dpctl_work_group_memory(self):
248+
cdef str p_name = "dpctl_work_group_memory"
249+
return kernel_arg_type_attribute(
250+
self._name,
251+
p_name,
252+
_arg_data_type._WORK_GROUP_MEMORY
253+
)
254+
245255

246256
kernel_arg_type = _kernel_arg_type()
247257

@@ -824,6 +834,9 @@ cdef class SyclQueue(_SyclQueue):
824834
elif isinstance(arg, _Memory):
825835
kargs[idx]= <void*>(<size_t>arg._pointer)
826836
kargty[idx] = _arg_data_type._VOID_PTR
837+
elif isinstance(arg, WorkGroupMemory):
838+
kargs[idx] = <void*>(<size_t>arg.nbytes)
839+
kargty[idx] = _arg_data_type._WORK_GROUP_MEMORY
827840
else:
828841
ret = -1
829842
return ret

dpctl/experimental/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
file(GLOB _cython_sources *.pyx)
2+
foreach(_cy_file ${_cython_sources})
3+
get_filename_component(_trgt ${_cy_file} NAME_WLE)
4+
build_dpctl_ext(${_trgt} ${_cy_file} "dpctl/experimental" RELATIVE_PATH "..")
5+
target_include_directories(${_trgt} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include)
6+
target_link_libraries(DpctlCAPI INTERFACE ${_trgt}_headers)
7+
endforeach()

dpctl/experimental/__init__.pxd

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
# distutils: language = c++
18+
# cython: language_level=3
19+
20+
"""This file declares the extension types and functions for the Cython API
21+
implemented in dpctl.experimental.*.pyx.
22+
"""
23+
24+
25+
from dpctl.experimental._work_group_memory cimport *

dpctl/experimental/__init__.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""
18+
**Data Parallel Control Experimental" provides Python objects to interface
19+
with different experimental SYCL language extensions defined by the DPC++
20+
SYCL implementation.
21+
"""
22+
23+
from ._work_group_memory import (
24+
WorkGroupMemory,
25+
)
26+
27+
__all__ = [
28+
"WorkGroupMemory",
29+
]
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
# distutils: language = c++
18+
# cython: language_level=3
19+
20+
cdef public api class WorkGroupMemory [object PyWorkGroupMemoryObject, type PyWorkGroupMemoryType]:
21+
cdef Py_ssize_t nbytes
22+
23+
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
# distutils: language = c++
18+
# cython: language_level=3
19+
# cython: linetrace=True
20+
21+
cdef class WorkGroupMemory:
22+
"""
23+
WorkGroupMemory(nbytes)
24+
Python class representing the ``work_group_memory`` class from the
25+
Workgroup Memory oneAPI SYCL extension for low-overhead allocation of local
26+
memory shared by the workitems in a workgroup.
27+
28+
Args:
29+
nbytes (int)
30+
number of bytes to allocate in local memory.
31+
Expected to be positive.
32+
"""
33+
def __cinit__(self, Py_ssize_t nbytes):
34+
self.nbytes = nbytes
35+
36+
property nbytes:
37+
"""Local memory size in bytes."""
38+
def __get__(self):
39+
return self.nbytes
40+
41+

dpctl/tests/test_sycl_kernel_submit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,3 +278,4 @@ def test_kernel_arg_type():
278278
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_float64)
279279
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_void_ptr)
280280
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_local_accessor)
281+
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_work_group_memory)

libsyclinterface/include/syclinterface/dpctl_sycl_enum_types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ typedef enum
100100
DPCTL_FLOAT64_T,
101101
DPCTL_VOID_PTR,
102102
DPCTL_LOCAL_ACCESSOR,
103+
DPCTL_WORK_GROUP_MEMORY,
103104
DPCTL_UNSUPPORTED_KERNEL_ARG
104105
} DPCTLKernelArgType;
105106

libsyclinterface/source/dpctl_sycl_queue_interface.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,14 @@ bool set_kernel_arg(handler &cgh,
216216
case DPCTL_LOCAL_ACCESSOR:
217217
arg_set = set_local_accessor_arg(cgh, idx, (MDLocalAccessor *)Arg);
218218
break;
219+
case DPCTL_WORK_GROUP_MEMORY:
220+
{
221+
size_t num_bytes = reinterpret_cast<std::uintptr_t>(Arg);
222+
sycl::ext::oneapi::experimental::work_group_memory<char[]> mem{
223+
num_bytes, cgh};
224+
cgh.set_arg(idx, mem);
225+
break;
226+
}
219227
default:
220228
arg_set = false;
221229
break;

0 commit comments

Comments
 (0)