Skip to content

Commit 4cafebf

Browse files
authored
CTS for linear_indexing language feature (#4595)
* Execution tests for new builtins * Validation tests for new builtins * Api dispatch validation tests
1 parent 67c7a2f commit 4cafebf

5 files changed

Lines changed: 219 additions & 10 deletions

File tree

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
export const description = `
2+
Compute dispatch validation tests.
3+
`;
4+
5+
import { AllFeaturesMaxLimitsGPUTest } from '../.././gpu_test.js';
6+
import { makeTestGroup } from '../../../common/framework/test_group.js';
7+
8+
export const g = makeTestGroup(AllFeaturesMaxLimitsGPUTest);
9+
10+
g.test('dispatch,linear_indexing_range')
11+
.desc('Tests validation of total invocations for linear_indexing built-in values')
12+
.params(u =>
13+
u
14+
.combine('builtin', ['global_invocation_index', 'workgroup_index'] as const)
15+
.beginSubcases()
16+
.combine('size', ['max', 'valid'] as const)
17+
)
18+
.fn(t => {
19+
// Other builtins are not tested due to onerous runtimes.
20+
t.skipIf(!t.hasLanguageFeature('linear_indexing'), 'Missing linear_indexing language feature');
21+
22+
// Spec limits:
23+
// - maxComputeWorkgroupsPerDimension = 65535
24+
const { maxComputeWorkgroupsPerDimension } = t.device.limits;
25+
const x = t.params.builtin === 'global_invocation_index' ? 2 : 1,
26+
y = 1,
27+
z = 1;
28+
const wgSize = x * y * z;
29+
const countX = maxComputeWorkgroupsPerDimension;
30+
const countY = t.params.size === 'max' ? maxComputeWorkgroupsPerDimension : 1;
31+
const countZ = t.params.builtin === 'workgroup_index' ? 2 : 1;
32+
33+
const totalInvocations = wgSize * countX * countY * countZ;
34+
t.skipIf(t.params.size === 'max' && totalInvocations <= 0xffffffff, 'Uninteresting test');
35+
36+
const code = `
37+
@compute @workgroup_size(${x}, ${y}, ${z})
38+
fn main(@builtin(${t.params.builtin}) input : u32) {
39+
_ = input;
40+
}`;
41+
42+
const shaderModule = t.device.createShaderModule({ code });
43+
const computePipeline = t.device.createComputePipeline({
44+
layout: 'auto',
45+
compute: {
46+
module: shaderModule,
47+
},
48+
});
49+
const commandEncoder = t.device.createCommandEncoder();
50+
const computePassEncoder = commandEncoder.beginComputePass();
51+
computePassEncoder.setPipeline(computePipeline);
52+
computePassEncoder.dispatchWorkgroups(countX, countY, countZ);
53+
computePassEncoder.end();
54+
55+
t.expectValidationError(() => {
56+
commandEncoder.finish();
57+
}, t.params.size === 'max');
58+
});
59+
60+
g.test('dispatchIndirect,linear_indexing_range')
61+
.desc('Tests dispatchIndirect skips when linear_indexing is out of range')
62+
.params(u =>
63+
u
64+
.combine('builtin', ['global_invocation_index', 'workgroup_index'] as const)
65+
.beginSubcases()
66+
.combine('size', ['max', 'valid'] as const)
67+
)
68+
.fn(t => {
69+
// Other builtins are not tested due to onerous runtimes.
70+
t.skipIf(!t.hasLanguageFeature('linear_indexing'), 'Missing linear_indexing language feature');
71+
72+
// Spec limits:
73+
// - maxComputeWorkgroupsPerDimension = 65535
74+
const { maxComputeWorkgroupsPerDimension } = t.device.limits;
75+
const x = t.params.builtin === 'global_invocation_index' ? 2 : 1,
76+
y = 1,
77+
z = 1;
78+
const wgSize = x * y * z;
79+
const countX = maxComputeWorkgroupsPerDimension;
80+
const countY = t.params.size === 'max' ? maxComputeWorkgroupsPerDimension : 1;
81+
const countZ = t.params.builtin === 'workgroup_index' ? 2 : 1;
82+
83+
const totalInvocations = wgSize * countX * countY * countZ;
84+
t.skipIf(t.params.size === 'max' && totalInvocations <= 0xffffffff, 'Uninteresting test');
85+
86+
const kMagic = 0xdeadbeef;
87+
const code = `
88+
@group(0) @binding(0)
89+
var<storage, read_write> out : u32;
90+
91+
@compute @workgroup_size(${x}, ${y}, ${z})
92+
fn main(@builtin(${t.params.builtin}) input : u32,
93+
@builtin(global_invocation_id) gid : vec3u) {
94+
_ = input;
95+
if (gid.x == 0 && gid.y == 0 && gid.z == 0) {
96+
out = ${kMagic};
97+
}
98+
}`;
99+
100+
const dispatchIndirectCounts = new Uint32Array(3);
101+
dispatchIndirectCounts[0] = countX;
102+
dispatchIndirectCounts[1] = countY;
103+
dispatchIndirectCounts[2] = countZ;
104+
const indirectBuffer = t.makeBufferWithContents(
105+
dispatchIndirectCounts,
106+
GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.INDIRECT
107+
);
108+
t.trackForCleanup(indirectBuffer);
109+
const outputBuffer = t.makeBufferWithContents(
110+
new Uint32Array([0]),
111+
GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE
112+
);
113+
t.trackForCleanup(outputBuffer);
114+
115+
const shaderModule = t.device.createShaderModule({ code });
116+
const computePipeline = t.device.createComputePipeline({
117+
layout: 'auto',
118+
compute: {
119+
module: shaderModule,
120+
},
121+
});
122+
const bg = t.device.createBindGroup({
123+
layout: computePipeline.getBindGroupLayout(0),
124+
entries: [
125+
{
126+
binding: 0,
127+
resource: {
128+
buffer: outputBuffer,
129+
},
130+
},
131+
],
132+
});
133+
const commandEncoder = t.device.createCommandEncoder();
134+
const computePassEncoder = commandEncoder.beginComputePass();
135+
computePassEncoder.setPipeline(computePipeline);
136+
computePassEncoder.setBindGroup(0, bg);
137+
computePassEncoder.dispatchWorkgroupsIndirect(indirectBuffer, 0);
138+
computePassEncoder.end();
139+
t.queue.submit([commandEncoder.finish()]);
140+
141+
const expected = t.params.size === 'max' ? 0 : kMagic;
142+
t.expectGPUBufferValuesEqual(outputBuffer, new Uint32Array([expected]));
143+
});

src/webgpu/capability_info.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,6 +981,7 @@ export const kKnownWGSLLanguageFeatures = [
981981
'subgroup_id',
982982
'subgroup_uniformity',
983983
'swizzle_assignment',
984+
'linear_indexing',
984985
] as const;
985986

986987
export type WGSLLanguageFeature = (typeof kKnownWGSLLanguageFeatures)[number];

src/webgpu/listing_meta.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,8 @@
509509
"webgpu:api,validation,createView:texture_view_usage_with_view_format:*": { "subcaseMS": 2406.440 },
510510
"webgpu:api,validation,debugMarker:push_pop_call_count_unbalance,command_encoder:*": { "subcaseMS": 1.522 },
511511
"webgpu:api,validation,debugMarker:push_pop_call_count_unbalance,render_compute_pass:*": { "subcaseMS": 0.601 },
512+
"webgpu:api,validation,dispatch:dispatch,linear_indexing_range:*": { "subcaseMS": 359.656 },
513+
"webgpu:api,validation,dispatch:dispatchIndirect,linear_indexing_range:*": { "subcaseMS": 320.426 },
512514
"webgpu:api,validation,encoding,beginComputePass:timestampWrites,invalid_query_set:*": { "subcaseMS": 0.201 },
513515
"webgpu:api,validation,encoding,beginComputePass:timestampWrites,query_index:*": { "subcaseMS": 0.201 },
514516
"webgpu:api,validation,encoding,beginComputePass:timestampWrites,query_set_type:*": { "subcaseMS": 0.401 },

src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ g.test('inputs')
3434
.beginSubcases()
3535
)
3636
.fn(t => {
37+
const linear_indexing = t.hasLanguageFeature('linear_indexing');
3738
const invocationsPerGroup = t.params.groupSize.x * t.params.groupSize.y * t.params.groupSize.z;
3839
const totalInvocations =
3940
invocationsPerGroup * t.params.numGroups.x * t.params.numGroups.y * t.params.numGroups.z;
@@ -46,6 +47,8 @@ g.test('inputs')
4647
let global_id = '';
4748
let group_id = '';
4849
let num_groups = '';
50+
let global_index = '';
51+
let group_index = '';
4952
switch (t.params.method) {
5053
case 'param':
5154
params = `
@@ -54,12 +57,18 @@ g.test('inputs')
5457
@builtin(global_invocation_id) global_id : vec3<u32>,
5558
@builtin(workgroup_id) group_id : vec3<u32>,
5659
@builtin(num_workgroups) num_groups : vec3<u32>,
60+
${linear_indexing ? '@builtin(global_invocation_index) global_index : u32,' : ''}
61+
${linear_indexing ? '@builtin(workgroup_index) group_index : u32,' : ''}
5762
`;
5863
local_id = 'local_id';
5964
local_index = 'local_index';
6065
global_id = 'global_id';
6166
group_id = 'group_id';
6267
num_groups = 'num_groups';
68+
if (linear_indexing) {
69+
global_index = 'global_index';
70+
group_index = 'group_index';
71+
}
6372
break;
6473
case 'struct':
6574
structures = `struct Inputs {
@@ -68,13 +77,19 @@ g.test('inputs')
6877
@builtin(global_invocation_id) global_id : vec3<u32>,
6978
@builtin(workgroup_id) group_id : vec3<u32>,
7079
@builtin(num_workgroups) num_groups : vec3<u32>,
80+
${linear_indexing ? '@builtin(global_invocation_index) global_index : u32,' : ''}
81+
${linear_indexing ? '@builtin(workgroup_index) group_index : u32,' : ''}
7182
};`;
7283
params = `inputs : Inputs`;
7384
local_id = 'inputs.local_id';
7485
local_index = 'inputs.local_index';
7586
global_id = 'inputs.global_id';
7687
group_id = 'inputs.group_id';
7788
num_groups = 'inputs.num_groups';
89+
if (linear_indexing) {
90+
global_index = 'inputs.global_index';
91+
group_index = 'inputs.group_index';
92+
}
7893
break;
7994
case 'mixed':
8095
structures = `struct InputsA {
@@ -87,12 +102,19 @@ g.test('inputs')
87102
params = `@builtin(local_invocation_id) local_id : vec3<u32>,
88103
inputsA : InputsA,
89104
inputsB : InputsB,
90-
@builtin(num_workgroups) num_groups : vec3<u32>,`;
105+
@builtin(num_workgroups) num_groups : vec3<u32>,
106+
${linear_indexing ? '@builtin(global_invocation_index) global_index : u32,' : ''}
107+
${linear_indexing ? '@builtin(workgroup_index) group_index : u32,' : ''}
108+
`;
91109
local_id = 'local_id';
92110
local_index = 'inputsA.local_index';
93111
global_id = 'inputsA.global_id';
94112
group_id = 'inputsB.group_id';
95113
num_groups = 'num_groups';
114+
if (linear_indexing) {
115+
global_index = 'global_index';
116+
group_index = 'group_index';
117+
}
96118
break;
97119
}
98120

@@ -104,6 +126,8 @@ g.test('inputs')
104126
global_id: vec3u,
105127
group_id: vec3u,
106128
num_groups: vec3u,
129+
${linear_indexing ? 'global_index : u32,' : ''}
130+
${linear_indexing ? 'group_index : u32,' : ''}
107131
};
108132
@group(0) @binding(0) var<storage, read_write> outputs : array<Outputs>;
109133
@@ -117,15 +141,17 @@ g.test('inputs')
117141
fn main(
118142
${params}
119143
) {
120-
let group_index = ((${group_id}.z * ${num_groups}.y) + ${group_id}.y) * ${num_groups}.x + ${group_id}.x;
121-
let global_index = group_index * ${invocationsPerGroup}u + ${local_index};
144+
let o_group_index = ((${group_id}.z * ${num_groups}.y) + ${group_id}.y) * ${num_groups}.x + ${group_id}.x;
145+
let o_global_index = o_group_index * ${invocationsPerGroup}u + ${local_index};
122146
var o: Outputs;
123147
o.local_id = ${local_id};
124148
o.local_index = ${local_index};
125149
o.global_id = ${global_id};
126150
o.group_id = ${group_id};
127151
o.num_groups = ${num_groups};
128-
outputs[global_index] = o;
152+
${linear_indexing ? `o.global_index = ${global_index};` : ``}
153+
${linear_indexing ? `o.group_index = ${group_index};` : ``}
154+
outputs[o_global_index] = o;
129155
}
130156
`;
131157

@@ -145,7 +171,9 @@ g.test('inputs')
145171
const kGlobalIdOffset = 4;
146172
const kGroupIdOffset = 8;
147173
const kNumGroupsOffset = 12;
148-
const kOutputElementSize = 16;
174+
const kGlobalIndexOffset = 15;
175+
const kGroupIndexOffset = 16;
176+
const kOutputElementSize = linear_indexing ? 20 : 16;
149177

150178
// Create the output buffers.
151179
const outputBuffer = t.createBufferTracked({
@@ -203,6 +231,21 @@ g.test('inputs')
203231
const localIndex = (lz * t.params.groupSize.y + ly) * t.params.groupSize.x + lx;
204232
const globalIndex = groupIndex * invocationsPerGroup + localIndex;
205233
const globalOffset = globalIndex * kOutputElementSize;
234+
const gidX = gx * t.params.groupSize.x + lx;
235+
const gidY = gy * t.params.groupSize.y + ly;
236+
const gidZ = gz * t.params.groupSize.z + lz;
237+
const globalLinearIndex =
238+
gidX +
239+
gidY * t.params.groupSize.x * t.params.numGroups.x +
240+
gidZ *
241+
t.params.groupSize.x *
242+
t.params.numGroups.x *
243+
t.params.groupSize.y *
244+
t.params.numGroups.y;
245+
const groupLinearIndex =
246+
gx +
247+
gy * t.params.numGroups.x +
248+
gz * t.params.numGroups.x * t.params.numGroups.y;
206249

207250
const expectEqual = (name: string, expected: number, actual: number) => {
208251
if (actual !== expected) {
@@ -226,17 +269,23 @@ g.test('inputs')
226269

227270
const error =
228271
checkVec3Value('local_id', kLocalIdOffset, { x: lx, y: ly, z: lz }) ||
229-
checkVec3Value('global_id', kGlobalIdOffset, {
230-
x: gx * t.params.groupSize.x + lx,
231-
y: gy * t.params.groupSize.y + ly,
232-
z: gz * t.params.groupSize.z + lz,
233-
}) ||
272+
checkVec3Value('global_id', kGlobalIdOffset, { x: gidX, y: gidY, z: gidZ }) ||
234273
checkVec3Value('group_id', kGroupIdOffset, { x: gx, y: gy, z: gz }) ||
235274
checkVec3Value('num_groups', kNumGroupsOffset, t.params.numGroups) ||
236275
expectEqual(
237276
'local_index',
238277
localIndex,
239278
output[globalOffset + kLocalIndexOffset]
279+
) ||
280+
expectEqual(
281+
'global_index',
282+
globalLinearIndex,
283+
output[globalOffset + kGlobalIndexOffset]
284+
) ||
285+
expectEqual(
286+
'group_index',
287+
groupLinearIndex,
288+
output[globalOffset + kGroupIndexOffset]
240289
);
241290
if (error) {
242291
return error;

src/webgpu/shader/validation/shader_io/builtins.spec.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,20 @@ export const kBuiltins: readonly Builtin[] = [
119119
enable: 'subgroups',
120120
requires: 'subgroup_id',
121121
},
122+
{
123+
name: 'workgroup_index',
124+
stage: 'compute',
125+
io: 'in',
126+
type: 'u32',
127+
requires: 'linear_indexing',
128+
},
129+
{
130+
name: 'global_invocation_index',
131+
stage: 'compute',
132+
io: 'in',
133+
type: 'u32',
134+
requires: 'linear_indexing',
135+
},
122136
] as const;
123137

124138
// List of types to test against.

0 commit comments

Comments
 (0)