Skip to content

Commit cea97a8

Browse files
authored
Add tests for override-sized array pointer parameters (#4632)
Fixes #4629 * Validation and execution tests
1 parent 9211907 commit cea97a8

2 files changed

Lines changed: 163 additions & 16 deletions

File tree

src/webgpu/shader/execution/expression/call/user/ptr_params.spec.ts

Lines changed: 128 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,34 @@ import { GPUTest } from '../../../../../gpu_test.js';
77

88
export const g = makeTestGroup(GPUTest);
99

10-
function wgslTypeDecl(kind: 'vec4i' | 'array' | 'struct') {
10+
function wgslTypeDecl(
11+
kind: 'vec4i' | 'array' | 'override_array1' | 'override_array2' | 'override_array3' | 'struct'
12+
) {
1113
switch (kind) {
1214
case 'vec4i':
1315
return `
1416
alias T = vec4i;
17+
alias RT = T;
1518
`;
1619
case 'array':
1720
return `
1821
alias T = array<vec4f, 3>;
22+
alias RT = T;
23+
`;
24+
case 'override_array1':
25+
return `
26+
alias T = array<vec4f, over_no_default>;
27+
alias RT = array<vec4f, 3>;
28+
`;
29+
case 'override_array2':
30+
return `
31+
alias T = array<vec4f, over_default>;
32+
alias RT = array<vec4f, 3>;
33+
`;
34+
case 'override_array3':
35+
return `
36+
alias T = array<vec4f, over_expr>;
37+
alias RT = array<vec4f, 3>;
1938
`;
2039
case 'struct':
2140
return `
@@ -26,15 +45,21 @@ c : i32,
2645
d : u32,
2746
}
2847
alias T = S;
48+
alias RT = T;
2949
`;
3050
}
3151
}
3252

33-
function valuesForType(kind: 'vec4i' | 'array' | 'struct') {
53+
function valuesForType(
54+
kind: 'vec4i' | 'array' | 'override_array1' | 'override_array2' | 'override_array3' | 'struct'
55+
) {
3456
switch (kind) {
3557
case 'vec4i':
3658
return new Uint32Array([1, 2, 3, 4]);
3759
case 'array':
60+
case 'override_array1':
61+
case 'override_array2':
62+
case 'override_array3':
3863
return new Float32Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
3964
case 'struct':
4065
return new Uint32Array([1, 2, 3, 4]);
@@ -46,13 +71,15 @@ function run(
4671
wgsl: string,
4772
inputUsage: 'uniform' | 'storage',
4873
input: Uint32Array | Float32Array,
49-
expected: Uint32Array | Float32Array
74+
expected: Uint32Array | Float32Array,
75+
constants: Record<string, number> = {}
5076
) {
5177
const pipeline = t.device.createComputePipeline({
5278
layout: 'auto',
5379
compute: {
5480
module: t.device.createShaderModule({ code: wgsl }),
5581
entryPoint: 'main',
82+
constants,
5683
},
5784
});
5885

@@ -91,7 +118,24 @@ g.test('read_full_object')
91118
u
92119
.combine('address_space', ['function', 'private', 'workgroup', 'storage', 'uniform'] as const)
93120
.combine('call_indirection', [0, 1, 2] as const)
94-
.combine('type', ['vec4i', 'array', 'struct'] as const)
121+
.combine('type', [
122+
'vec4i',
123+
'array',
124+
'override_array1',
125+
'override_array2',
126+
'override_array3',
127+
'struct',
128+
] as const)
129+
.filter(t => {
130+
switch (t.type) {
131+
case 'override_array1':
132+
case 'override_array2':
133+
case 'override_array3':
134+
return t.address_space === 'workgroup';
135+
default:
136+
return true;
137+
}
138+
})
95139
)
96140
.fn(t => {
97141
switch (t.params.address_space) {
@@ -101,6 +145,27 @@ g.test('read_full_object')
101145
t.skipIfLanguageFeatureNotSupported('unrestricted_pointer_parameters');
102146
}
103147

148+
let wg_assign_input = 'W = input;';
149+
let output_assign = 'output = *p;';
150+
if (t.params.address_space === 'workgroup') {
151+
switch (t.params.type) {
152+
case 'override_array1':
153+
case 'override_array2':
154+
case 'override_array3':
155+
wg_assign_input = `
156+
for (var i = 0u; i < 3; i++) {
157+
W[i] = input[i];
158+
}`;
159+
output_assign = `
160+
for (var i = 0u; i < 3; i++) {
161+
output[i] = (*p)[i];
162+
}`;
163+
break;
164+
default:
165+
break;
166+
}
167+
}
168+
104169
const main: string = {
105170
function: `
106171
@compute @workgroup_size(1)
@@ -121,7 +186,7 @@ fn main() {
121186
var<workgroup> W : T;
122187
@compute @workgroup_size(1)
123188
fn main() {
124-
W = input;
189+
${wg_assign_input}
125190
f0(&W);
126191
}
127192
`,
@@ -150,18 +215,21 @@ fn f${i}(p : ptr<${t.params.address_space}, T>) {
150215

151216
const inputVar: string =
152217
t.params.address_space === 'uniform'
153-
? `@binding(0) @group(0) var<uniform> input : T;`
154-
: `@binding(0) @group(0) var<storage, read> input : T;`;
218+
? `@binding(0) @group(0) var<uniform> input : RT;`
219+
: `@binding(0) @group(0) var<storage, read> input : RT;`;
155220

156221
const wgsl = `
222+
override over_no_default : u32;
223+
override over_default = 1u;
224+
override over_expr = over_default + over_no_default - 3u;
157225
${wgslTypeDecl(t.params.type)}
158226
159227
${inputVar}
160228
161-
@binding(1) @group(0) var<storage, read_write> output : T;
229+
@binding(1) @group(0) var<storage, read_write> output : RT;
162230
163231
fn f${t.params.call_indirection}(p : ptr<${t.params.address_space}, T>) {
164-
output = *p;
232+
${output_assign}
165233
}
166234
167235
${call_chain}
@@ -171,7 +239,10 @@ ${main}
171239

172240
const values = valuesForType(t.params.type);
173241

174-
run(t, wgsl, t.params.address_space === 'uniform' ? 'uniform' : 'storage', values, values);
242+
run(t, wgsl, t.params.address_space === 'uniform' ? 'uniform' : 'storage', values, values, {
243+
over_no_default: 3,
244+
over_default: 3,
245+
});
175246
});
176247

177248
g.test('read_ptr_to_member')
@@ -374,7 +445,24 @@ g.test('write_full_object')
374445
u
375446
.combine('address_space', ['function', 'private', 'workgroup', 'storage'] as const)
376447
.combine('call_indirection', [0, 1, 2] as const)
377-
.combine('type', ['vec4i', 'array', 'struct'] as const)
448+
.combine('type', [
449+
'vec4i',
450+
'array',
451+
'override_array1',
452+
'override_array2',
453+
'override_array3',
454+
'struct',
455+
] as const)
456+
.filter(t => {
457+
switch (t.type) {
458+
case 'override_array1':
459+
case 'override_array2':
460+
case 'override_array3':
461+
return t.address_space === 'workgroup';
462+
default:
463+
return true;
464+
}
465+
})
378466
)
379467
.fn(t => {
380468
switch (t.params.address_space) {
@@ -383,6 +471,27 @@ g.test('write_full_object')
383471
t.skipIfLanguageFeatureNotSupported('unrestricted_pointer_parameters');
384472
}
385473

474+
let wg_output_assign = 'output = W;';
475+
let assign_from_input = '*p = input;';
476+
if (t.params.address_space === 'workgroup') {
477+
switch (t.params.type) {
478+
case 'override_array1':
479+
case 'override_array2':
480+
case 'override_array3':
481+
wg_output_assign = `
482+
for (var i = 0u; i < 3; i++) {
483+
output[i] = W[i];
484+
}`;
485+
assign_from_input = `
486+
for (var i = 0u; i < 3; i++) {
487+
(*p)[i] = input[i];
488+
}`;
489+
break;
490+
default:
491+
break;
492+
}
493+
}
494+
386495
const ptr =
387496
t.params.address_space === 'storage'
388497
? `ptr<storage, T, read_write>`
@@ -410,7 +519,7 @@ var<workgroup> W : T;
410519
@compute @workgroup_size(1)
411520
fn main() {
412521
f0(&W);
413-
output = W;
522+
${wg_output_assign}
414523
}
415524
`,
416525
storage: `
@@ -431,13 +540,16 @@ fn f${i}(p : ${ptr}) {
431540
}
432541

433542
const wgsl = `
543+
override over_no_default : u32;
544+
override over_default = 1u;
545+
override over_expr = over_default + over_no_default - 3u;
434546
${wgslTypeDecl(t.params.type)}
435547
436-
@binding(0) @group(0) var<uniform> input : T;
437-
@binding(1) @group(0) var<storage, read_write> output : T;
548+
@binding(0) @group(0) var<uniform> input : RT;
549+
@binding(1) @group(0) var<storage, read_write> output : RT;
438550
439551
fn f${t.params.call_indirection}(p : ${ptr}) {
440-
*p = input;
552+
${assign_from_input}
441553
}
442554
443555
${call_chain}
@@ -447,7 +559,7 @@ ${main}
447559

448560
const values = valuesForType(t.params.type);
449561

450-
run(t, wgsl, 'uniform', values, values);
562+
run(t, wgsl, 'uniform', values, values, { over_no_default: 3, over_default: 3 });
451563
});
452564

453565
g.test('write_ptr_to_member')

src/webgpu/shader/validation/functions/restrictions.spec.ts

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ struct struct_with_array {
3434
a : array<constructible, 4>
3535
}
3636
37+
override override_no_default : u32;
38+
override override_default = 4u;
39+
override override_expr = override_default + 2;
40+
3741
`;
3842

3943
const kVertexPosCases: Record<string, VertexPosCase> = {
@@ -278,6 +282,18 @@ const kFunctionParamTypeCases: Record<string, ParamTypeCase> = {
278282
name: `ptr<workgroup, array<atomic<u32>,1>>`,
279283
valid: 'with_unrestricted_pointer_parameters',
280284
},
285+
ptrWorkgroupOverrideNoDefault: {
286+
name: `ptr<workgroup, array<u32, override_no_default>>`,
287+
valid: 'with_unrestricted_pointer_parameters',
288+
},
289+
ptrWorkgroupOverrideWithDefault: {
290+
name: `ptr<workgroup, array<f32, override_default>>`,
291+
valid: 'with_unrestricted_pointer_parameters',
292+
},
293+
ptrWorkgroupOverrideExpr: {
294+
name: `ptr<workgroup, array<vec4f, override_expr>>`,
295+
valid: 'with_unrestricted_pointer_parameters',
296+
},
281297

282298
// Invalid pointers.
283299
invalid_ptr1: { name: `ptr<handle, u32>`, valid: false }, // Can't spell handle address space
@@ -488,6 +504,21 @@ const kFunctionParamValueCases: Record<string, ParamValueCase> = {
488504
matches: ['ptr12'],
489505
needsUnrestrictedPointerParameters: true,
490506
},
507+
ptrWorkgroupOverrideNoDefault: {
508+
value: `&wg_override_no_default`,
509+
matches: ['ptrWorkgroupOverrideNoDefault'],
510+
needsUnrestrictedPointerParameters: true,
511+
},
512+
ptrWorkgroupOverrideWithDefault: {
513+
value: `&wg_override_default`,
514+
matches: ['ptrWorkgroupOverrideWithDefault'],
515+
needsUnrestrictedPointerParameters: true,
516+
},
517+
ptrWorkgroupOverrideExpr: {
518+
value: `&wg_override_expr`,
519+
matches: ['ptrWorkgroupOverrideExpr'],
520+
needsUnrestrictedPointerParameters: true,
521+
},
491522
};
492523

493524
function parameterMatches(decl: string, matches: string[]): boolean {
@@ -569,6 +600,10 @@ var<private> g_array5 : array<bool, 4>;
569600
var<private> g_constructible : constructible;
570601
var<private> g_struct_with_array : struct_with_array;
571602
603+
var<workgroup> wg_override_no_default : array<u32, override_no_default>;
604+
var<workgroup> wg_override_default : array<f32, override_default>;
605+
var<workgroup> wg_override_expr : array<vec4f, override_expr>;
606+
572607
fn foo() {
573608
var f_u32 : u32;
574609
var f_i32 : i32;

0 commit comments

Comments
 (0)