Skip to content

Commit b1706fe

Browse files
committed
Support compilation from SYCL source code
1 parent 80213b4 commit b1706fe

12 files changed

+786
-9
lines changed

dpctl/_backend.pxd

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,9 +287,12 @@ cdef extern from "syclinterface/dpctl_sycl_device_interface.h":
287287
_peer_access PT)
288288
cdef void DPCTLDevice_EnablePeerAccess(const DPCTLSyclDeviceRef DRef,
289289
const DPCTLSyclDeviceRef PDRef)
290-
291290
cdef void DPCTLDevice_DisablePeerAccess(const DPCTLSyclDeviceRef DRef,
292291
const DPCTLSyclDeviceRef PDRef)
292+
cdef bool DPCTLDevice_CanCompileSPIRV(const DPCTLSyclDeviceRef DRef)
293+
cdef bool DPCTLDevice_CanCompileOpenCL(const DPCTLSyclDeviceRef DRef)
294+
cdef bool DPCTLDevice_CanCompileSYCL(const DPCTLSyclDeviceRef DRef)
295+
293296

294297
cdef extern from "syclinterface/dpctl_sycl_device_manager.h":
295298
cdef DPCTLDeviceVectorRef DPCTLDeviceVector_CreateFromArray(
@@ -452,6 +455,43 @@ cdef extern from "syclinterface/dpctl_sycl_kernel_bundle_interface.h":
452455
cdef DPCTLSyclKernelBundleRef DPCTLKernelBundle_Copy(
453456
const DPCTLSyclKernelBundleRef KBRef)
454457

458+
cdef struct DPCTLBuildOptionList
459+
cdef struct DPCTLKernelNameList
460+
cdef struct DPCTLVirtualHeaderList
461+
ctypedef DPCTLBuildOptionList* DPCTLBuildOptionListRef
462+
ctypedef DPCTLKernelNameList* DPCTLKernelNameListRef
463+
ctypedef DPCTLVirtualHeaderList* DPCTLVirtualHeaderListRef
464+
465+
cdef DPCTLBuildOptionListRef DPCTLBuildOptionList_Create()
466+
cdef void DPCTLBuildOptionList_Delete(DPCTLBuildOptionListRef Ref)
467+
cdef void DPCTLBuildOptionList_Append(DPCTLBuildOptionListRef Ref,
468+
const char *Option)
469+
470+
cdef DPCTLKernelNameListRef DPCTLKernelNameList_Create()
471+
cdef void DPCTLKernelNameList_Delete(DPCTLKernelNameListRef Ref)
472+
cdef void DPCTLKernelNameList_Append(DPCTLKernelNameListRef Ref,
473+
const char *Option)
474+
475+
cdef DPCTLVirtualHeaderListRef DPCTLVirtualHeaderList_Create()
476+
cdef void DPCTLVirtualHeaderList_Delete(DPCTLVirtualHeaderListRef Ref)
477+
cdef void DPCTLVirtualHeaderList_Append(DPCTLVirtualHeaderListRef Ref,
478+
const char *Name,
479+
const char *Content)
480+
481+
cdef DPCTLSyclKernelBundleRef DPCTLKernelBundle_CreateFromSYCLSource(
482+
const DPCTLSyclContextRef Ctx,
483+
const DPCTLSyclDeviceRef Dev,
484+
const char *Source,
485+
DPCTLVirtualHeaderListRef Headers,
486+
DPCTLKernelNameListRef Names,
487+
DPCTLBuildOptionListRef BuildOptions)
488+
489+
cdef DPCTLSyclKernelRef DPCTLKernelBundle_GetSyclKernel(DPCTLSyclKernelBundleRef KBRef,
490+
const char *KernelName)
491+
492+
cdef bool DPCTLKernelBundle_HasSyclKernel(DPCTLSyclKernelBundleRef KBRef,
493+
const char *KernelName);
494+
455495

456496
cdef extern from "syclinterface/dpctl_sycl_queue_interface.h":
457497
ctypedef struct _md_local_accessor "MDLocalAccessor":

dpctl/_sycl_device.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,4 @@ cdef public api class SyclDevice(_SyclDevice) [
6161
cdef int get_overall_ordinal(self)
6262
cdef int get_backend_ordinal(self)
6363
cdef int get_backend_and_device_type_ordinal(self)
64+
cpdef bint can_compile(self, str language)

dpctl/_sycl_device.pyx

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ from ._backend cimport ( # noqa: E211
2626
DPCTLDefaultSelector_Create,
2727
DPCTLDevice_AreEq,
2828
DPCTLDevice_CanAccessPeer,
29+
DPCTLDevice_CanCompileOpenCL,
30+
DPCTLDevice_CanCompileSPIRV,
31+
DPCTLDevice_CanCompileSYCL,
2932
DPCTLDevice_Copy,
3033
DPCTLDevice_CreateFromSelector,
3134
DPCTLDevice_CreateSubDevicesByAffinity,
@@ -2367,6 +2370,35 @@ cdef class SyclDevice(_SyclDevice):
23672370
raise ValueError("device could not be found")
23682371
return dev_id
23692372

2373+
cpdef bint can_compile(self, str language):
2374+
"""
2375+
Check whether it is possible to create an executable kernel_bundle
2376+
for this device from the given source language.
2377+
2378+
Parameters:
2379+
language
2380+
Input language. Possible values are "spirv" for SPIR-V binary
2381+
files, "opencl" for OpenCL C device code and "sycl" for SYCL
2382+
device code.
2383+
2384+
Returns:
2385+
bool:
2386+
True if compilation is supported, False otherwise.
2387+
2388+
Raises:
2389+
ValueError:
2390+
If an unknown source language is used.
2391+
"""
2392+
if language == "spirv" or language == "spv":
2393+
return DPCTLDevice_CanCompileSYCL(self._device_ref)
2394+
if language == "opencl" or language == "ocl":
2395+
return DPCTLDevice_CanCompileOpenCL(self._device_ref)
2396+
if language == "sycl":
2397+
return DPCTLDevice_CanCompileSYCL(self._device_ref)
2398+
2399+
raise ValueError(f"Unknown source language {language}")
2400+
2401+
23702402

23712403
cdef api DPCTLSyclDeviceRef SyclDevice_GetDeviceRef(SyclDevice dev):
23722404
"""

dpctl/program/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@
2626
SyclProgramCompilationError,
2727
create_program_from_source,
2828
create_program_from_spirv,
29+
create_program_from_sycl_source,
2930
)
3031

3132
__all__ = [
3233
"create_program_from_source",
3334
"create_program_from_spirv",
35+
"create_program_from_sycl_source",
3436
"SyclKernel",
3537
"SyclProgram",
3638
"SyclProgramCompilationError",

dpctl/program/_program.pxd

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,17 @@ cdef api class SyclProgram [object PySyclProgramObject, type PySyclProgramType]:
4949
binary file.
5050
"""
5151
cdef DPCTLSyclKernelBundleRef _program_ref
52+
cdef bint _is_sycl_source
5253

5354
@staticmethod
54-
cdef SyclProgram _create (DPCTLSyclKernelBundleRef pref)
55+
cdef SyclProgram _create (DPCTLSyclKernelBundleRef pref, bint _is_sycl_source)
5556
cdef DPCTLSyclKernelBundleRef get_program_ref (self)
5657
cpdef SyclKernel get_sycl_kernel(self, str kernel_name)
5758

5859

5960
cpdef create_program_from_source (SyclQueue q, unicode source, unicode copts=*)
6061
cpdef create_program_from_spirv (SyclQueue q, const unsigned char[:] IL,
6162
unicode copts=*)
63+
cpdef create_program_from_sycl_source(SyclQueue q, unicode source,
64+
list headers=*, list registered_names=*,
65+
list copts=*)

dpctl/program/_program.pyx

Lines changed: 126 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ a OpenCL source string or a SPIR-V binary file.
2828
from libc.stdint cimport uint32_t
2929

3030
from dpctl._backend cimport ( # noqa: E211, E402;
31+
DPCTLBuildOptionList_Append,
32+
DPCTLBuildOptionList_Create,
33+
DPCTLBuildOptionList_Delete,
34+
DPCTLBuildOptionListRef,
3135
DPCTLKernel_Copy,
3236
DPCTLKernel_Delete,
3337
DPCTLKernel_GetCompileNumSubGroups,
@@ -41,13 +45,24 @@ from dpctl._backend cimport ( # noqa: E211, E402;
4145
DPCTLKernelBundle_Copy,
4246
DPCTLKernelBundle_CreateFromOCLSource,
4347
DPCTLKernelBundle_CreateFromSpirv,
48+
DPCTLKernelBundle_CreateFromSYCLSource,
4449
DPCTLKernelBundle_Delete,
4550
DPCTLKernelBundle_GetKernel,
51+
DPCTLKernelBundle_GetSyclKernel,
4652
DPCTLKernelBundle_HasKernel,
53+
DPCTLKernelBundle_HasSyclKernel,
54+
DPCTLKernelNameList_Append,
55+
DPCTLKernelNameList_Create,
56+
DPCTLKernelNameList_Delete,
57+
DPCTLKernelNameListRef,
4758
DPCTLSyclContextRef,
4859
DPCTLSyclDeviceRef,
4960
DPCTLSyclKernelBundleRef,
5061
DPCTLSyclKernelRef,
62+
DPCTLVirtualHeaderList_Append,
63+
DPCTLVirtualHeaderList_Create,
64+
DPCTLVirtualHeaderList_Delete,
65+
DPCTLVirtualHeaderListRef,
5166
)
5267

5368
__all__ = [
@@ -196,9 +211,10 @@ cdef class SyclProgram:
196211
"""
197212

198213
@staticmethod
199-
cdef SyclProgram _create(DPCTLSyclKernelBundleRef KBRef):
214+
cdef SyclProgram _create(DPCTLSyclKernelBundleRef KBRef, bint is_sycl_source):
200215
cdef SyclProgram ret = SyclProgram.__new__(SyclProgram)
201216
ret._program_ref = KBRef
217+
ret._is_sycl_source = is_sycl_source
202218
return ret
203219

204220
def __dealloc__(self):
@@ -209,13 +225,19 @@ cdef class SyclProgram:
209225

210226
cpdef SyclKernel get_sycl_kernel(self, str kernel_name):
211227
name = kernel_name.encode("utf8")
228+
if self._is_sycl_source:
229+
return SyclKernel._create(
230+
DPCTLKernelBundle_GetSyclKernel(self._program_ref, name),
231+
kernel_name)
212232
return SyclKernel._create(
213233
DPCTLKernelBundle_GetKernel(self._program_ref, name),
214234
kernel_name
215235
)
216236

217237
def has_sycl_kernel(self, str kernel_name):
218238
name = kernel_name.encode("utf8")
239+
if self._is_sycl_source:
240+
return DPCTLKernelBundle_HasSyclKernel(self._program_ref, name)
219241
return DPCTLKernelBundle_HasKernel(self._program_ref, name)
220242

221243
def addressof_ref(self):
@@ -271,7 +293,7 @@ cpdef create_program_from_source(SyclQueue q, str src, str copts=""):
271293
if KBref is NULL:
272294
raise SyclProgramCompilationError()
273295

274-
return SyclProgram._create(KBref)
296+
return SyclProgram._create(KBref, False)
275297

276298

277299
cpdef create_program_from_spirv(SyclQueue q, const unsigned char[:] IL,
@@ -317,7 +339,107 @@ cpdef create_program_from_spirv(SyclQueue q, const unsigned char[:] IL,
317339
if KBref is NULL:
318340
raise SyclProgramCompilationError()
319341

320-
return SyclProgram._create(KBref)
342+
return SyclProgram._create(KBref, False)
343+
344+
345+
cpdef create_program_from_sycl_source(SyclQueue q, unicode source, list headers=[], list registered_names=[], list copts=[]):
346+
"""
347+
Creates an executable SYCL kernel_bundle from SYCL source code.
348+
349+
This uses the DPC++ ``kernel_compiler`` extension to create a
350+
``sycl::kernel_bundle<sycl::bundle_state::executable>`` object from
351+
SYCL source code.
352+
353+
Parameters:
354+
q (:class:`dpctl.SyclQueue`)
355+
The :class:`dpctl.SyclQueue` for which the
356+
:class:`.SyclProgram` is going to be built.
357+
source (unicode)
358+
SYCL source code string.
359+
headers (list)
360+
Optional list of virtual headers, where each entry in the list
361+
needs to be a tuple of header name and header content. See the
362+
documentation of the ``include_files`` property in the DPC++
363+
``kernel_compiler`` extension for more information.
364+
Default: []
365+
registered_names (list, optional)
366+
Optional list of kernel names to register. See the
367+
documentation of the ``registered_names`` property in the DPC++
368+
``kernel_compiler`` extension for more information.
369+
Default: []
370+
copts (list)
371+
Optional list of compilation flags that will be used
372+
when compiling the program. Default: ``""``.
373+
374+
Returns:
375+
program (:class:`.SyclProgram`)
376+
A :class:`.SyclProgram` object wrapping the
377+
``sycl::kernel_bundle<sycl::bundle_state::executable>``
378+
returned by the C API.
379+
380+
Raises:
381+
SyclProgramCompilationError
382+
If a SYCL kernel bundle could not be created.
383+
"""
384+
cdef DPCTLSyclKernelBundleRef KBref
385+
cdef DPCTLSyclContextRef CRef = q.get_sycl_context().get_context_ref()
386+
cdef DPCTLSyclDeviceRef DRef = q.get_sycl_device().get_device_ref()
387+
cdef bytes bSrc = source.encode('utf8')
388+
cdef const char *Src = <const char*>bSrc
389+
cdef DPCTLBuildOptionListRef BuildOpts = DPCTLBuildOptionList_Create()
390+
cdef bytes bOpt
391+
cdef const char* sOpt
392+
cdef bytes bName
393+
cdef const char* sName
394+
cdef bytes bContent
395+
cdef const char* sContent
396+
for opt in copts:
397+
if not isinstance(opt, unicode):
398+
DPCTLBuildOptionList_Delete(BuildOpts)
399+
raise SyclProgramCompilationError()
400+
bOpt = opt.encode('utf8')
401+
sOpt = <const char*>bOpt
402+
DPCTLBuildOptionList_Append(BuildOpts, sOpt)
403+
404+
cdef DPCTLKernelNameListRef KernelNames = DPCTLKernelNameList_Create()
405+
for name in registered_names:
406+
if not isinstance(name, unicode):
407+
DPCTLBuildOptionList_Delete(BuildOpts)
408+
DPCTLKernelNameList_Delete(KernelNames)
409+
raise SyclProgramCompilationError()
410+
bName = name.encode('utf8')
411+
sName = <const char*>bName
412+
DPCTLKernelNameList_Append(KernelNames, sName)
413+
414+
415+
cdef DPCTLVirtualHeaderListRef VirtualHeaders = DPCTLVirtualHeaderList_Create()
416+
for name, content in headers:
417+
if not isinstance(name, unicode) or not isinstance(content, unicode):
418+
DPCTLBuildOptionList_Delete(BuildOpts)
419+
DPCTLKernelNameList_Delete(KernelNames)
420+
DPCTLVirtualHeaderList_Delete(VirtualHeaders)
421+
raise SyclProgramCompilationError()
422+
bName = name.encode('utf8')
423+
sName = <const char*>bName
424+
bContent = content.encode('utf8')
425+
sContent = <const char*>bContent
426+
DPCTLVirtualHeaderList_Append(VirtualHeaders, sName, sContent)
427+
428+
KBref = DPCTLKernelBundle_CreateFromSYCLSource(CRef, DRef, Src,
429+
VirtualHeaders, KernelNames,
430+
BuildOpts)
431+
432+
if KBref is NULL:
433+
DPCTLBuildOptionList_Delete(BuildOpts)
434+
DPCTLKernelNameList_Delete(KernelNames)
435+
DPCTLVirtualHeaderList_Delete(VirtualHeaders)
436+
raise SyclProgramCompilationError()
437+
438+
DPCTLBuildOptionList_Delete(BuildOpts)
439+
DPCTLKernelNameList_Delete(KernelNames)
440+
DPCTLVirtualHeaderList_Delete(VirtualHeaders)
441+
442+
return SyclProgram._create(KBref, True)
321443

322444

323445
cdef api DPCTLSyclKernelBundleRef SyclProgram_GetKernelBundleRef(
@@ -336,4 +458,4 @@ cdef api SyclProgram SyclProgram_Make(DPCTLSyclKernelBundleRef KBRef):
336458
reference.
337459
"""
338460
cdef DPCTLSyclKernelBundleRef copied_KBRef = DPCTLKernelBundle_Copy(KBRef)
339-
return SyclProgram._create(copied_KBRef)
461+
return SyclProgram._create(copied_KBRef, False)

0 commit comments

Comments
 (0)