Skip to content

Commit 2493f89

Browse files
committed
Implement SpecializationConstant class
1 parent 1a0a910 commit 2493f89

5 files changed

Lines changed: 184 additions & 1 deletion

File tree

dpctl/_backend.pxd

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,10 @@ cdef extern from "syclinterface/dpctl_sycl_context_interface.h":
431431

432432

433433
cdef extern from "syclinterface/dpctl_sycl_kernel_bundle_interface.h":
434+
ctypedef struct _spec_const "DPCTLSpecConst":
435+
uint32_t id
436+
size_t size
437+
const void *value
434438
cdef DPCTLSyclKernelBundleRef DPCTLKernelBundle_CreateFromSpirv(
435439
const DPCTLSyclContextRef Ctx,
436440
const DPCTLSyclDeviceRef Dev,

dpctl/_sycl_platform.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ cdef class SyclPlatform(_SyclPlatform):
236236
and filter string for each device is printed.
237237
238238
Args:
239-
verbosity (Literal[0, 1, 2], optional):.
239+
verbosity (Literal[0, 1, 2], optional):
240240
The verbosity controls how much information is printed by the
241241
function. Value ``0`` is the lowest level set by default and
242242
``2`` is the highest level to print the most verbose output.

dpctl/program/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"""
2323

2424
from ._program import (
25+
SpecializationConstant,
2526
SyclKernel,
2627
SyclKernelBundle,
2728
SyclKernelBundleCompilationError,
@@ -41,6 +42,7 @@
4142
"SyclKernelBundleCompilationError",
4243
"SyclProgram",
4344
"SyclProgramCompilationError",
45+
"SpecializationConstant",
4446
]
4547

4648

dpctl/program/_program.pyx

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,17 @@ an OpenCL source string or a SPIR-V binary file.
2626
2727
"""
2828

29+
from cpython.buffer cimport (
30+
Py_buffer,
31+
PyBUF_ANY_CONTIGUOUS,
32+
PyBUF_SIMPLE,
33+
PyBuffer_Release,
34+
PyObject_CheckBuffer,
35+
PyObject_GetBuffer,
36+
)
37+
from cpython.bytes cimport PyBytes_FromStringAndSize
2938
from libc.stdint cimport uint32_t
39+
from libc.string cimport memcmp
3040

3141
import warnings
3242

@@ -51,14 +61,20 @@ from dpctl._backend cimport ( # noqa: E211, E402;
5161
DPCTLSyclDeviceRef,
5262
DPCTLSyclKernelBundleRef,
5363
DPCTLSyclKernelRef,
64+
_spec_const,
5465
)
5566

67+
import numbers
68+
69+
import numpy as np
70+
5671
__all__ = [
5772
"create_kernel_bundle_from_source",
5873
"create_kernel_bundle_from_spirv",
5974
"SyclKernel",
6075
"SyclKernelBundle",
6176
"SyclKernelBundleCompilationError",
77+
"SpecializationConstant",
6278
]
6379

6480
cdef class SyclKernelBundleCompilationError(Exception):
@@ -252,6 +268,160 @@ cdef api SyclKernelBundle SyclKernelBundle_Make(DPCTLSyclKernelBundleRef KBRef):
252268
return SyclKernelBundle._create(copied_KBRef)
253269

254270

271+
cdef class SpecializationConstant:
272+
"""
273+
SpecializationConstant(spec_id, *args)
274+
275+
Python class representing SYCL specialization constants that can be used
276+
when creating a :class:`dpctl.program.SyclKernelBundle` from SPIR-V.
277+
278+
There are multiple ways to create a :class:`.SpecializationConstant`:
279+
280+
- ``SpecializationConstant(spec_id, obj)``
281+
If the constructor is invoked with a single variadic argument, the
282+
argument is expected to either expose the Python buffer protocol or be
283+
coercible to a NumPy array. If the argument is coercible to a NumPy array
284+
or is one, it must have a supported data type (bool, integral, floating
285+
point, or void). The specialization constant will be constructed from the
286+
data in the buffer
287+
288+
- ``SpecializationConstant(spec_id, dtype, obj)``
289+
If the constructor is invoked with two variadic arguments, and the first
290+
argument is a string, it is interpreted as a NumPy ``dtype`` string and the
291+
second argument will be coerced to a NumPy array with that data type.
292+
The data type specified by the first argument must be a supported data
293+
type (bool, integral, floating point, or void).
294+
295+
- ``SpecializationConstant(spec_id, nbytes, raw_ptr)``
296+
If the constructor is invoked with two variadic arguments where both are
297+
integers, the first argument is interpreted as the number of bytes and
298+
the second argument is interpreted as a pointer to the data.
299+
300+
Note that when constructing from a buffer, the
301+
:class:`.SpecializationConstant`, shares memory with the original object.
302+
Modifications to the original object's data after construction will be
303+
reflected when the :class:`.SpecializationConstant` is used to create a
304+
:class:`.SyclKernelBundle`. This is not the case when constructing from a
305+
raw pointer, as the data is copied.
306+
307+
Args:
308+
spec_id (int):
309+
The SPIR-V specialization ID.
310+
args:
311+
Variadic argument, see class documentation.
312+
313+
Raises:
314+
TypeError: In case of incorrect arguments given to constructor,
315+
failure to coerce to a buffer, or unsupported data type when
316+
coercing to a buffer.
317+
ValueError: If the provided object fails to construct a buffer.
318+
"""
319+
320+
cdef _spec_const _spec_const
321+
cdef Py_buffer _buffer
322+
323+
def __cinit__(self, spec_id, *args):
324+
cdef int ret_code = 0
325+
cdef object target_obj = None
326+
327+
if not isinstance(spec_id, numbers.Integral):
328+
raise TypeError(
329+
"Specialization constant ID must be of type `int`, got "
330+
f"{type(spec_id)}"
331+
)
332+
333+
if len(args) == 0 or len(args) > 2:
334+
raise TypeError(
335+
f"Constructor takes 2 or 3 arguments, got {len(args)}."
336+
)
337+
338+
self._spec_const.id = <uint32_t>spec_id
339+
340+
if len(args) == 2:
341+
if (
342+
isinstance(args[0], numbers.Integral) and
343+
isinstance(args[1], numbers.Integral)
344+
):
345+
target_obj = PyBytes_FromStringAndSize(
346+
<const char *><size_t>args[1], <Py_ssize_t>args[0]
347+
)
348+
elif isinstance(args[0], str):
349+
target_obj = np.ascontiguousarray(args[1], dtype=args[0])
350+
351+
elif len(args) == 1:
352+
target_obj = args[0]
353+
if not PyObject_CheckBuffer(target_obj):
354+
# attempt to coerce to a numpy array
355+
target_obj = np.ascontiguousarray(target_obj)
356+
else:
357+
raise TypeError(
358+
"Invalid arguments."
359+
)
360+
361+
if isinstance(target_obj, np.ndarray):
362+
if target_obj.dtype.kind not in ("b", "i", "u", "f", "c", "V"):
363+
raise TypeError(
364+
"Coercion of input to buffer resulted in an unsupported "
365+
f"data type '{target_obj.dtype}'. When coercing objects, "
366+
"`SpecializationConstant` expects the data to coerce to a "
367+
"supported type: bool, integral, real or complex floating "
368+
"point, or void. To pass arbitrary data, use a "
369+
"`memoryview` or `bytes` object, or pass the pointer and "
370+
"size directly."
371+
)
372+
373+
ret_code = PyObject_GetBuffer(
374+
target_obj, &(self._buffer), PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS
375+
)
376+
if ret_code != 0:
377+
raise ValueError(
378+
"Failed to get buffer view for the provided object."
379+
)
380+
self._spec_const.value = <void*>self._buffer.buf
381+
self._spec_const.size = <size_t>self._buffer.len
382+
383+
def __dealloc__(self):
384+
PyBuffer_Release(&(self._buffer))
385+
386+
def __repr__(self):
387+
return f"SpecializationConstant({self._spec_const.id})"
388+
389+
def __eq__(self, other):
390+
if not isinstance(other, SpecializationConstant):
391+
return False
392+
cdef SpecializationConstant _other = <SpecializationConstant>other
393+
if (
394+
self._spec_const.id != _other._spec_const.id or
395+
self._spec_const.size != _other._spec_const.size or
396+
self._spec_const.value != _other._spec_const.value
397+
):
398+
return False
399+
return memcmp(
400+
self._spec_const.value,
401+
_other._spec_const.value,
402+
self._spec_const.size
403+
) == 0
404+
405+
@property
406+
def id(self):
407+
"""Returns the specialization ID for this specialization constant."""
408+
return self._spec_const.id
409+
410+
@property
411+
def size(self):
412+
"""
413+
Returns the size in bytes of the data for this specialization constant.
414+
"""
415+
return self._spec_const.size
416+
417+
cdef size_t addressof(self):
418+
"""
419+
Returns the address of the _spec_const for this
420+
:class:`.SpecializationConstant` cast to ``size_t``.
421+
"""
422+
return <size_t>&(self._spec_const)
423+
424+
255425
cpdef create_kernel_bundle_from_source(SyclQueue q, str src, str copts=""):
256426
"""
257427
Creates a Sycl interoperability kernel bundle from an OpenCL source

libsyclinterface/include/syclinterface/dpctl_sycl_kernel_bundle_interface.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@
3535

3636
DPCTL_C_EXTERN_C_BEGIN
3737

38+
typedef struct DPCTLSpecConstTy
39+
{
40+
uint32_t id;
41+
size_t size;
42+
const void *value;
43+
} DPCTLSpecConst;
44+
3845
/**
3946
* @defgroup KernelBundleInterface Kernel_bundle class C wrapper
4047
*/

0 commit comments

Comments
 (0)