3131#include " dpctl_error_handlers.h"
3232#include " dpctl_sycl_type_casters.hpp"
3333#include < CL/cl.h> /* OpenCL headers */
34+ #include < cstdint>
3435#include < sstream>
3536#include < stddef.h>
3637#include < sycl/backend/opencl.hpp>
@@ -170,6 +171,21 @@ std::string _GetErrorCode_ocl_impl(cl_int code)
170171 }
171172}
172173
174+ typedef cl_int (*clSetProgramSpecializationConstantFT)(cl_program,
175+ cl_uint,
176+ size_t ,
177+ const void *);
178+ const char *clSetProgramSpecializationConstant_Name =
179+ " clSetProgramSpecializationConstant" ;
180+ clSetProgramSpecializationConstantFT get_clSetProgramSpecializationConstant ()
181+ {
182+ static auto st_clSetProgramSpecializationConstantF =
183+ cl_loader::get ().getSymbol <clSetProgramSpecializationConstantFT>(
184+ clSetProgramSpecializationConstant_Name);
185+
186+ return st_clSetProgramSpecializationConstantF;
187+ }
188+
173189DPCTLSyclKernelBundleRef
174190_CreateKernelBundle_common_ocl_impl (cl_program clProgram,
175191 const context &ctx,
@@ -235,7 +251,9 @@ _CreateKernelBundleWithIL_ocl_impl(const context &ctx,
235251 const device &dev,
236252 const void *IL,
237253 size_t il_length,
238- const char *CompileOpts)
254+ const char *CompileOpts,
255+ size_t NumSpecConsts,
256+ const DPCTLSpecConst *SpecConsts)
239257{
240258 auto clCreateProgramWithILF = get_clCreateProgramWithIL ();
241259 if (clCreateProgramWithILF == nullptr ) {
@@ -257,6 +275,22 @@ _CreateKernelBundleWithIL_ocl_impl(const context &ctx,
257275 return nullptr ;
258276 }
259277
278+ if (SpecConsts != nullptr && NumSpecConsts > 0 ) {
279+ auto clSetProgramSpecConstF = get_clSetProgramSpecializationConstant ();
280+ if (clSetProgramSpecConstF) {
281+ for (size_t i = 0 ; i < NumSpecConsts; ++i) {
282+ clSetProgramSpecConstF (clProgram, SpecConsts[i].id ,
283+ SpecConsts[i].size , SpecConsts[i].value );
284+ }
285+ }
286+ else {
287+ error_handler (" clSetProgramSpecializationConstant is not available "
288+ " in the OpenCL implementation." ,
289+ __FILE__, __func__, __LINE__);
290+ return nullptr ;
291+ }
292+ }
293+
260294 return _CreateKernelBundle_common_ocl_impl (clProgram, ctx, dev,
261295 CompileOpts);
262296}
@@ -428,7 +462,9 @@ _CreateKernelBundleWithIL_ze_impl(const context &SyclCtx,
428462 const device &SyclDev,
429463 const void *IL,
430464 size_t il_length,
431- const char *CompileOpts)
465+ const char *CompileOpts,
466+ size_t NumSpecConsts,
467+ const DPCTLSpecConst *SpecConsts)
432468{
433469 auto zeModuleCreateFn = get_zeModuleCreate ();
434470 if (zeModuleCreateFn == nullptr ) {
@@ -444,8 +480,22 @@ _CreateKernelBundleWithIL_ze_impl(const context &SyclCtx,
444480 ZeDevice = get_native<ze_be>(SyclDev);
445481
446482 // Specialization constants are not supported by DPCTL at the moment
483+ std::vector<std::uint32_t > spec_ids;
484+ std::vector<const void *> spec_values;
485+
486+ if (SpecConsts != nullptr && NumSpecConsts > 0 ) {
487+ spec_ids.reserve (NumSpecConsts);
488+ spec_values.reserve (NumSpecConsts);
489+ for (size_t i = 0 ; i < NumSpecConsts; ++i) {
490+ spec_ids.push_back (SpecConsts[i].id );
491+ spec_values.push_back (SpecConsts[i].value );
492+ }
493+ }
447494 ze_module_constants_t ZeSpecConstants = {};
448- ZeSpecConstants.numConstants = 0 ;
495+ ZeSpecConstants.numConstants = static_cast <std::uint32_t >(NumSpecConsts);
496+ ZeSpecConstants.pConstantIds = spec_ids.empty () ? nullptr : spec_ids.data ();
497+ ZeSpecConstants.pConstantValues =
498+ spec_values.empty () ? nullptr : spec_values.data ();
449499
450500 // Populate the Level Zero module descriptions
451501 ze_module_desc_t ZeModuleDesc = {};
@@ -583,7 +633,9 @@ DPCTLKernelBundle_CreateFromSpirv(__dpctl_keep const DPCTLSyclContextRef CtxRef,
583633 __dpctl_keep const DPCTLSyclDeviceRef DevRef,
584634 __dpctl_keep const void *IL,
585635 size_t length,
586- const char *CompileOpts)
636+ const char *CompileOpts,
637+ size_t NumSpecConsts,
638+ const DPCTLSpecConst *SpecConsts)
587639{
588640 DPCTLSyclKernelBundleRef KBRef = nullptr ;
589641 if (!CtxRef) {
@@ -611,12 +663,14 @@ DPCTLKernelBundle_CreateFromSpirv(__dpctl_keep const DPCTLSyclContextRef CtxRef,
611663 switch (BE) {
612664 case backend::opencl:
613665 KBRef = _CreateKernelBundleWithIL_ocl_impl (*SyclCtx, *SyclDev, IL,
614- length, CompileOpts);
666+ length, CompileOpts,
667+ NumSpecConsts, SpecConsts);
615668 break ;
616669 case backend::ext_oneapi_level_zero:
617670#ifdef DPCTL_ENABLE_L0_PROGRAM_CREATION
618671 KBRef = _CreateKernelBundleWithIL_ze_impl (*SyclCtx, *SyclDev, IL,
619- length, CompileOpts);
672+ length, CompileOpts,
673+ NumSpecConsts, SpecConsts);
620674 break ;
621675#endif
622676 default :
0 commit comments