@@ -26,7 +26,17 @@ an OpenCL source string or a SPIR-V binary file.
2626
2727"""
2828
29+ from cpython.buffer cimport (
30+ Py_buffer,
31+ PyBUF_ANY_CONTIGUOUS,
32+ PyBUF_SIMPLE,
33+ PyBuffer_Release,
34+ PyObject_CheckBuffer,
35+ PyObject_GetBuffer,
36+ )
37+ from cpython.bytes cimport PyBytes_FromStringAndSize
2938from libc.stdint cimport uint32_t
39+ from libc.string cimport memcmp
3040
3141import warnings
3242
@@ -51,14 +61,20 @@ from dpctl._backend cimport ( # noqa: E211, E402;
5161 DPCTLSyclDeviceRef,
5262 DPCTLSyclKernelBundleRef,
5363 DPCTLSyclKernelRef,
64+ _spec_const,
5465)
5566
67+ import numbers
68+
69+ import numpy as np
70+
5671__all__ = [
5772 " create_kernel_bundle_from_source" ,
5873 " create_kernel_bundle_from_spirv" ,
5974 " SyclKernel" ,
6075 " SyclKernelBundle" ,
6176 " SyclKernelBundleCompilationError" ,
77+ " SpecializationConstant" ,
6278]
6379
6480cdef class SyclKernelBundleCompilationError(Exception ):
@@ -252,6 +268,160 @@ cdef api SyclKernelBundle SyclKernelBundle_Make(DPCTLSyclKernelBundleRef KBRef):
252268 return SyclKernelBundle._create(copied_KBRef)
253269
254270
271+ cdef class SpecializationConstant:
272+ """
273+ SpecializationConstant(spec_id, *args)
274+
275+ Python class representing SYCL specialization constants that can be used
276+ when creating a :class:`dpctl.program.SyclKernelBundle` from SPIR-V.
277+
278+ There are multiple ways to create a :class:`.SpecializationConstant`:
279+
280+ - ``SpecializationConstant(spec_id, obj)``
281+ If the constructor is invoked with a single variadic argument, the
282+ argument is expected to either expose the Python buffer protocol or be
283+ coercible to a NumPy array. If the argument is coercible to a NumPy array
284+ or is one, it must have a supported data type (bool, integral, floating
285+ point, or void). The specialization constant will be constructed from the
286+ data in the buffer
287+
288+ - ``SpecializationConstant(spec_id, dtype, obj)``
289+ If the constructor is invoked with two variadic arguments, and the first
290+ argument is a string, it is interpreted as a NumPy ``dtype`` string and the
291+ second argument will be coerced to a NumPy array with that data type.
292+ The data type specified by the first argument must be a supported data
293+ type (bool, integral, floating point, or void).
294+
295+ - ``SpecializationConstant(spec_id, nbytes, raw_ptr)``
296+ If the constructor is invoked with two variadic arguments where both are
297+ integers, the first argument is interpreted as the number of bytes and
298+ the second argument is interpreted as a pointer to the data.
299+
300+ Note that when constructing from a buffer, the
301+ :class:`.SpecializationConstant`, shares memory with the original object.
302+ Modifications to the original object's data after construction will be
303+ reflected when the :class:`.SpecializationConstant` is used to create a
304+ :class:`.SyclKernelBundle`. This is not the case when constructing from a
305+ raw pointer, as the data is copied.
306+
307+ Args:
308+ spec_id (int):
309+ The SPIR-V specialization ID.
310+ args:
311+ Variadic argument, see class documentation.
312+
313+ Raises:
314+ TypeError: In case of incorrect arguments given to constructor,
315+ failure to coerce to a buffer, or unsupported data type when
316+ coercing to a buffer.
317+ ValueError: If the provided object fails to construct a buffer.
318+ """
319+
320+ cdef _spec_const _spec_const
321+ cdef Py_buffer _buffer
322+
323+ def __cinit__ (self , spec_id , *args ):
324+ cdef int ret_code = 0
325+ cdef object target_obj = None
326+
327+ if not isinstance (spec_id, numbers.Integral):
328+ raise TypeError (
329+ " Specialization constant ID must be of type `int`, got "
330+ f" {type(spec_id)}"
331+ )
332+
333+ if len (args) == 0 or len (args) > 2 :
334+ raise TypeError (
335+ f" Constructor takes 2 or 3 arguments, got {len(args)}."
336+ )
337+
338+ self ._spec_const.id = < uint32_t> spec_id
339+
340+ if len (args) == 2 :
341+ if (
342+ isinstance (args[0 ], numbers.Integral) and
343+ isinstance (args[1 ], numbers.Integral)
344+ ):
345+ target_obj = PyBytes_FromStringAndSize(
346+ < const char * >< size_t> args[1 ], < Py_ssize_t> args[0 ]
347+ )
348+ elif isinstance (args[0 ], str ):
349+ target_obj = np.ascontiguousarray(args[1 ], dtype = args[0 ])
350+
351+ elif len (args) == 1 :
352+ target_obj = args[0 ]
353+ if not PyObject_CheckBuffer(target_obj):
354+ # attempt to coerce to a numpy array
355+ target_obj = np.ascontiguousarray(target_obj)
356+ else :
357+ raise TypeError (
358+ " Invalid arguments."
359+ )
360+
361+ if isinstance (target_obj, np.ndarray):
362+ if target_obj.dtype.kind not in (" b" , " i" , " u" , " f" , " c" , " V" ):
363+ raise TypeError (
364+ " Coercion of input to buffer resulted in an unsupported "
365+ f" data type '{target_obj.dtype}'. When coercing objects, "
366+ " `SpecializationConstant` expects the data to coerce to a "
367+ " supported type: bool, integral, real or complex floating "
368+ " point, or void. To pass arbitrary data, use a "
369+ " `memoryview` or `bytes` object, or pass the pointer and "
370+ " size directly."
371+ )
372+
373+ ret_code = PyObject_GetBuffer(
374+ target_obj, & (self ._buffer), PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS
375+ )
376+ if ret_code != 0 :
377+ raise ValueError (
378+ " Failed to get buffer view for the provided object."
379+ )
380+ self ._spec_const.value = < void * > self ._buffer.buf
381+ self ._spec_const.size = < size_t> self ._buffer.len
382+
383+ def __dealloc__ (self ):
384+ PyBuffer_Release(& (self ._buffer))
385+
386+ def __repr__ (self ):
387+ return f" SpecializationConstant({self._spec_const.id})"
388+
389+ def __eq__ (self , other ):
390+ if not isinstance (other, SpecializationConstant):
391+ return False
392+ cdef SpecializationConstant _other = < SpecializationConstant> other
393+ if (
394+ self ._spec_const.id != _other._spec_const.id or
395+ self ._spec_const.size != _other._spec_const.size or
396+ self ._spec_const.value != _other._spec_const.value
397+ ):
398+ return False
399+ return memcmp(
400+ self ._spec_const.value,
401+ _other._spec_const.value,
402+ self ._spec_const.size
403+ ) == 0
404+
405+ @property
406+ def id (self ):
407+ """ Returns the specialization ID for this specialization constant."""
408+ return self ._spec_const.id
409+
410+ @property
411+ def size (self ):
412+ """
413+ Returns the size in bytes of the data for this specialization constant.
414+ """
415+ return self ._spec_const.size
416+
417+ cdef size_t addressof(self ):
418+ """
419+ Returns the address of the _spec_const for this
420+ :class:`.SpecializationConstant` cast to ``size_t``.
421+ """
422+ return < size_t> & (self ._spec_const)
423+
424+
255425cpdef create_kernel_bundle_from_source(SyclQueue q, str src, str copts = " " ):
256426 """
257427 Creates a Sycl interoperability kernel bundle from an OpenCL source
0 commit comments