Skip to content

Commit d8b0afb

Browse files
committed
add test for composite specialization constant
also removes "v" as a permitted specialization constant intermediate data type, as composite specialization constants are broken into multiple specialization constants, so structs end up passed as a single constant while the program expects multiple, and therefore, doesn't work as intended
1 parent aba560b commit d8b0afb

3 files changed

Lines changed: 47 additions & 6 deletions

File tree

dpctl/program/_program.pyx

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -282,16 +282,16 @@ cdef class SpecializationConstant:
282282
If the constructor is invoked with a single variadic argument, the
283283
argument is expected to either expose the Python buffer protocol or be
284284
coercible to a NumPy array. If the argument is coercible to a NumPy array
285-
or is one, it must have a supported data type (bool, integral, floating
286-
point, or void). The specialization constant will be constructed from the
285+
or is one, it must have a supported data type (bool, integral, or
286+
floating point). The specialization constant will be constructed from the
287287
data in the buffer
288288
289289
- ``SpecializationConstant(spec_id, dtype, obj)``
290290
If the constructor is invoked with two variadic arguments, and the first
291291
argument is a string, it is interpreted as a NumPy ``dtype`` string and the
292292
second argument will be coerced to a NumPy array with that data type.
293293
The data type specified by the first argument must be a supported data
294-
type (bool, integral, floating point, or void).
294+
type (bool, integral, or floating point).
295295
296296
- ``SpecializationConstant(spec_id, nbytes, raw_ptr)``
297297
If the constructor is invoked with two variadic arguments where both are
@@ -360,13 +360,13 @@ cdef class SpecializationConstant:
360360
)
361361

362362
if isinstance(target_obj, np.ndarray):
363-
if target_obj.dtype.kind not in ("b", "i", "u", "f", "c", "V"):
363+
if target_obj.dtype.kind not in ("b", "i", "u", "f", "c"):
364364
raise TypeError(
365365
"Coercion of input to buffer resulted in an unsupported "
366366
f"data type '{target_obj.dtype}'. When coercing objects, "
367367
"`SpecializationConstant` expects the data to coerce to a "
368-
"supported type: bool, integral, real or complex floating "
369-
"point, or void. To pass arbitrary data, use a "
368+
"supported type: bool, integral, or real or complex "
369+
"floating point. To pass arbitrary data, use a "
370370
"`memoryview` or `bytes` object, or pass the pointer and "
371371
"size directly."
372372
)
3.35 KB
Binary file not shown.

dpctl/tests/test_sycl_program.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,3 +300,44 @@ def test_create_kernel_bundle_with_spec_const():
300300
ht_e.wait()
301301

302302
assert np.all(y == 43)
303+
304+
305+
def test_create_kernel_bundle_with_composite_spec_const():
306+
try:
307+
q = dpctl.SyclQueue()
308+
except dpctl.SyclQueueCreationError:
309+
pytest.skip("Could not create default queue")
310+
311+
# composite specialization constants are separated into individual
312+
# specialization constants with unique spec_ids
313+
sp1 = dpctl_prog.SpecializationConstant(0, "i4", 10)
314+
sp2 = dpctl_prog.SpecializationConstant(1, "f4", 2.5)
315+
sp3 = dpctl_prog.SpecializationConstant(2, "?", 1)
316+
317+
spirv_file = get_spirv_abspath("specialization_constant_composite.spv")
318+
with open(spirv_file, "br") as spv:
319+
spv_bytes = spv.read()
320+
321+
kb = dpctl_prog.create_kernel_bundle_from_spirv(
322+
q, spv_bytes, specializations=[sp1, sp2, sp3]
323+
)
324+
kernel = kb.get_sycl_kernel("_ZTS21StructSpecConstKernel")
325+
326+
n = 128
327+
x = np.ones(n, dtype="f4")
328+
y = np.zeros_like(x)
329+
330+
x_usm = dpctl.memory.MemoryUSMDevice(x.nbytes, queue=q)
331+
y_usm = dpctl.memory.MemoryUSMDevice(y.nbytes, queue=q)
332+
333+
e1 = q.memcpy_async(x_usm, x, x.nbytes)
334+
e2 = q.submit(kernel, [x_usm, y_usm], [n], dEvents=[e1])
335+
e3 = q.memcpy_async(y, y_usm, y.nbytes, [e2])
336+
337+
ht_e = q._submit_keep_args_alive([x_usm], [e3])
338+
339+
e3.wait()
340+
ht_e.wait()
341+
342+
# 1.0 * 10 + 2.5 = 12.5
343+
assert np.all(y == 12.5)

0 commit comments

Comments
 (0)