Skip to content

Commit 77e05dd

Browse files
committed
hook specialization constants into kernel bundle interface
1 parent 2493f89 commit 77e05dd

5 files changed

Lines changed: 118 additions & 16 deletions

File tree

dpctl/_backend.pxd

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,9 @@ cdef extern from "syclinterface/dpctl_sycl_kernel_bundle_interface.h":
440440
const DPCTLSyclDeviceRef Dev,
441441
const void *IL,
442442
size_t Length,
443-
const char *CompileOpts)
443+
const char *CompileOpts,
444+
size_t NumSpecConsts,
445+
const _spec_const *SpecConsts)
444446
cdef DPCTLSyclKernelBundleRef DPCTLKernelBundle_CreateFromOCLSource(
445447
const DPCTLSyclContextRef Ctx,
446448
const DPCTLSyclDeviceRef Dev,

dpctl/program/_program.pxd

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,10 @@ cpdef create_kernel_bundle_from_source (
6363
SyclQueue q, unicode source, unicode copts=*
6464
)
6565
cpdef create_kernel_bundle_from_spirv (
66-
SyclQueue q, const unsigned char[:] IL, unicode copts=*
66+
SyclQueue q,
67+
const unsigned char[:] IL,
68+
unicode copts=*,
69+
list specializations=*,
6770
)
6871
cpdef create_program_from_source (SyclQueue q, unicode source, unicode copts=*)
6972
cpdef create_program_from_spirv (

dpctl/program/_program.pyx

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ from cpython.buffer cimport (
3636
)
3737
from cpython.bytes cimport PyBytes_FromStringAndSize
3838
from libc.stdint cimport uint32_t
39+
from libc.stdlib cimport free, malloc
3940
from libc.string cimport memcmp
4041

4142
import warnings
@@ -469,7 +470,10 @@ cpdef create_kernel_bundle_from_source(SyclQueue q, str src, str copts=""):
469470

470471

471472
cpdef create_kernel_bundle_from_spirv(
472-
SyclQueue q, const unsigned char[:] IL, str copts=""
473+
SyclQueue q,
474+
const unsigned char[:] IL,
475+
str copts="",
476+
list specializations=None,
473477
):
474478
"""
475479
Creates a Sycl interoperability kernel bundle from an SPIR-V binary.
@@ -487,7 +491,9 @@ cpdef create_kernel_bundle_from_spirv(
487491
copts (str, optional)
488492
Optional compilation flags that will be used
489493
when compiling the kernel bundle. Default: ``""``.
490-
494+
specializations (list, optional)
495+
A list of :class:`.SpecializationConstant` objects to be used
496+
when creating the kernel bundle. Default: ``None``.
491497
Returns:
492498
kernel_bundle (:class:`.SyclKernelBundle`)
493499
A :class:`.SyclKernelBundle` object wrapping the
@@ -506,11 +512,44 @@ cpdef create_kernel_bundle_from_spirv(
506512
cdef size_t length = IL.shape[0]
507513
cdef bytes bCOpts = copts.encode("utf8")
508514
cdef const char *COpts = <const char*>bCOpts
509-
KBref = DPCTLKernelBundle_CreateFromSpirv(
510-
CRef, DRef, <const void*>dIL, length, COpts
511-
)
512-
if KBref is NULL:
513-
raise SyclKernelBundleCompilationError()
515+
cdef size_t num_spconsts
516+
cdef _spec_const *spconsts
517+
cdef SpecializationConstant spconst
518+
519+
if specializations is not None:
520+
num_spconsts = len(specializations)
521+
spconsts = <_spec_const *>(
522+
malloc(num_spconsts * sizeof(_spec_const))
523+
)
524+
if spconsts == NULL:
525+
raise MemoryError(
526+
"Failed to allocate memory for specialization constants."
527+
)
528+
for i, spconst in enumerate(specializations):
529+
if not isinstance(spconst, SpecializationConstant):
530+
free(spconsts)
531+
raise TypeError(
532+
"All items in specializations must be of type "
533+
f"`SpecializationConstant`, got {type(spconst)}"
534+
)
535+
spconsts[i] = spconst._spec_const
536+
else:
537+
num_spconsts = 0
538+
spconsts = NULL
539+
try:
540+
KBref = DPCTLKernelBundle_CreateFromSpirv(
541+
CRef,
542+
DRef,
543+
<const void*>dIL,
544+
length, COpts,
545+
num_spconsts,
546+
spconsts,
547+
)
548+
if KBref is NULL:
549+
raise SyclKernelBundleCompilationError()
550+
finally:
551+
if spconsts != NULL:
552+
free(spconsts)
514553

515554
return SyclKernelBundle._create(KBref)
516555

libsyclinterface/include/syclinterface/dpctl_sycl_kernel_bundle_interface.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ typedef struct DPCTLSpecConstTy
5858
* @param Length The size of the IL binary in bytes.
5959
* @param CompileOpts Optional compiler flags used when compiling the
6060
* SPIR-V binary.
61+
* @param NumSpecConsts The number of specialization constants.
62+
* @param SpecConsts An array of specialization constants.
6163
* @return A new SyclKernelBundleRef pointer if the kernel_bundle creation
6264
* succeeded, else returns NULL.
6365
* @ingroup KernelBundleInterface
@@ -68,7 +70,9 @@ DPCTLKernelBundle_CreateFromSpirv(__dpctl_keep const DPCTLSyclContextRef Ctx,
6870
__dpctl_keep const DPCTLSyclDeviceRef Dev,
6971
__dpctl_keep const void *IL,
7072
size_t Length,
71-
const char *CompileOpts);
73+
const char *CompileOpts,
74+
size_t NumSpecConsts,
75+
const DPCTLSpecConst *SpecConsts);
7276

7377
/*!
7478
* @brief Create a Sycl kernel bundle from an OpenCL kernel source string.

libsyclinterface/source/dpctl_sycl_kernel_bundle_interface.cpp

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "dpctl_error_handlers.h"
3232
#include "dpctl_sycl_type_casters.hpp"
3333
#include <CL/cl.h> /* OpenCL headers */
34+
#include <cstdint>
3435
#include <sstream>
3536
#include <stddef.h>
3637
#include <sycl/backend/opencl.hpp>
@@ -170,6 +171,21 @@ std::string _GetErrorCode_ocl_impl(cl_int code)
170171
}
171172
}
172173

174+
typedef cl_int (*clSetProgramSpecializationConstantFT)(cl_program,
175+
cl_uint,
176+
size_t,
177+
const void *);
178+
const char *clSetProgramSpecializationConstant_Name =
179+
"clSetProgramSpecializationConstant";
180+
clSetProgramSpecializationConstantFT get_clSetProgramSpecializationConstant()
181+
{
182+
static auto st_clSetProgramSpecializationConstantF =
183+
cl_loader::get().getSymbol<clSetProgramSpecializationConstantFT>(
184+
clSetProgramSpecializationConstant_Name);
185+
186+
return st_clSetProgramSpecializationConstantF;
187+
}
188+
173189
DPCTLSyclKernelBundleRef
174190
_CreateKernelBundle_common_ocl_impl(cl_program clProgram,
175191
const context &ctx,
@@ -235,7 +251,9 @@ _CreateKernelBundleWithIL_ocl_impl(const context &ctx,
235251
const device &dev,
236252
const void *IL,
237253
size_t il_length,
238-
const char *CompileOpts)
254+
const char *CompileOpts,
255+
size_t NumSpecConsts,
256+
const DPCTLSpecConst *SpecConsts)
239257
{
240258
auto clCreateProgramWithILF = get_clCreateProgramWithIL();
241259
if (clCreateProgramWithILF == nullptr) {
@@ -257,6 +275,22 @@ _CreateKernelBundleWithIL_ocl_impl(const context &ctx,
257275
return nullptr;
258276
}
259277

278+
if (SpecConsts != nullptr && NumSpecConsts > 0) {
279+
auto clSetProgramSpecConstF = get_clSetProgramSpecializationConstant();
280+
if (clSetProgramSpecConstF) {
281+
for (size_t i = 0; i < NumSpecConsts; ++i) {
282+
clSetProgramSpecConstF(clProgram, SpecConsts[i].id,
283+
SpecConsts[i].size, SpecConsts[i].value);
284+
}
285+
}
286+
else {
287+
error_handler("clSetProgramSpecializationConstant is not available "
288+
"in the OpenCL implementation.",
289+
__FILE__, __func__, __LINE__);
290+
return nullptr;
291+
}
292+
}
293+
260294
return _CreateKernelBundle_common_ocl_impl(clProgram, ctx, dev,
261295
CompileOpts);
262296
}
@@ -428,7 +462,9 @@ _CreateKernelBundleWithIL_ze_impl(const context &SyclCtx,
428462
const device &SyclDev,
429463
const void *IL,
430464
size_t il_length,
431-
const char *CompileOpts)
465+
const char *CompileOpts,
466+
size_t NumSpecConsts,
467+
const DPCTLSpecConst *SpecConsts)
432468
{
433469
auto zeModuleCreateFn = get_zeModuleCreate();
434470
if (zeModuleCreateFn == nullptr) {
@@ -444,8 +480,22 @@ _CreateKernelBundleWithIL_ze_impl(const context &SyclCtx,
444480
ZeDevice = get_native<ze_be>(SyclDev);
445481

446482
// Specialization constants are not supported by DPCTL at the moment
483+
std::vector<std::uint32_t> spec_ids;
484+
std::vector<const void *> spec_values;
485+
486+
if (SpecConsts != nullptr && NumSpecConsts > 0) {
487+
spec_ids.reserve(NumSpecConsts);
488+
spec_values.reserve(NumSpecConsts);
489+
for (size_t i = 0; i < NumSpecConsts; ++i) {
490+
spec_ids.push_back(SpecConsts[i].id);
491+
spec_values.push_back(SpecConsts[i].value);
492+
}
493+
}
447494
ze_module_constants_t ZeSpecConstants = {};
448-
ZeSpecConstants.numConstants = 0;
495+
ZeSpecConstants.numConstants = static_cast<std::uint32_t>(NumSpecConsts);
496+
ZeSpecConstants.pConstantIds = spec_ids.empty() ? nullptr : spec_ids.data();
497+
ZeSpecConstants.pConstantValues =
498+
spec_values.empty() ? nullptr : spec_values.data();
449499

450500
// Populate the Level Zero module descriptions
451501
ze_module_desc_t ZeModuleDesc = {};
@@ -583,7 +633,9 @@ DPCTLKernelBundle_CreateFromSpirv(__dpctl_keep const DPCTLSyclContextRef CtxRef,
583633
__dpctl_keep const DPCTLSyclDeviceRef DevRef,
584634
__dpctl_keep const void *IL,
585635
size_t length,
586-
const char *CompileOpts)
636+
const char *CompileOpts,
637+
size_t NumSpecConsts,
638+
const DPCTLSpecConst *SpecConsts)
587639
{
588640
DPCTLSyclKernelBundleRef KBRef = nullptr;
589641
if (!CtxRef) {
@@ -611,12 +663,14 @@ DPCTLKernelBundle_CreateFromSpirv(__dpctl_keep const DPCTLSyclContextRef CtxRef,
611663
switch (BE) {
612664
case backend::opencl:
613665
KBRef = _CreateKernelBundleWithIL_ocl_impl(*SyclCtx, *SyclDev, IL,
614-
length, CompileOpts);
666+
length, CompileOpts,
667+
NumSpecConsts, SpecConsts);
615668
break;
616669
case backend::ext_oneapi_level_zero:
617670
#ifdef DPCTL_ENABLE_L0_PROGRAM_CREATION
618671
KBRef = _CreateKernelBundleWithIL_ze_impl(*SyclCtx, *SyclDev, IL,
619-
length, CompileOpts);
672+
length, CompileOpts,
673+
NumSpecConsts, SpecConsts);
620674
break;
621675
#endif
622676
default:

0 commit comments

Comments
 (0)