@@ -34,6 +34,7 @@ g.test('inputs')
3434 . beginSubcases ( )
3535 )
3636 . fn ( t => {
37+ const linear_indexing = t . hasLanguageFeature ( 'linear_indexing' ) ;
3738 const invocationsPerGroup = t . params . groupSize . x * t . params . groupSize . y * t . params . groupSize . z ;
3839 const totalInvocations =
3940 invocationsPerGroup * t . params . numGroups . x * t . params . numGroups . y * t . params . numGroups . z ;
@@ -46,6 +47,8 @@ g.test('inputs')
4647 let global_id = '' ;
4748 let group_id = '' ;
4849 let num_groups = '' ;
50+ let global_index = '' ;
51+ let group_index = '' ;
4952 switch ( t . params . method ) {
5053 case 'param' :
5154 params = `
@@ -54,12 +57,18 @@ g.test('inputs')
5457 @builtin(global_invocation_id) global_id : vec3<u32>,
5558 @builtin(workgroup_id) group_id : vec3<u32>,
5659 @builtin(num_workgroups) num_groups : vec3<u32>,
60+ ${ linear_indexing ? '@builtin(global_invocation_index) global_index : u32,' : '' }
61+ ${ linear_indexing ? '@builtin(workgroup_index) group_index : u32,' : '' }
5762 ` ;
5863 local_id = 'local_id' ;
5964 local_index = 'local_index' ;
6065 global_id = 'global_id' ;
6166 group_id = 'group_id' ;
6267 num_groups = 'num_groups' ;
68+ if ( linear_indexing ) {
69+ global_index = 'global_index' ;
70+ group_index = 'group_index' ;
71+ }
6372 break ;
6473 case 'struct' :
6574 structures = `struct Inputs {
@@ -68,13 +77,19 @@ g.test('inputs')
6877 @builtin(global_invocation_id) global_id : vec3<u32>,
6978 @builtin(workgroup_id) group_id : vec3<u32>,
7079 @builtin(num_workgroups) num_groups : vec3<u32>,
80+ ${ linear_indexing ? '@builtin(global_invocation_index) global_index : u32,' : '' }
81+ ${ linear_indexing ? '@builtin(workgroup_index) group_index : u32,' : '' }
7182 };` ;
7283 params = `inputs : Inputs` ;
7384 local_id = 'inputs.local_id' ;
7485 local_index = 'inputs.local_index' ;
7586 global_id = 'inputs.global_id' ;
7687 group_id = 'inputs.group_id' ;
7788 num_groups = 'inputs.num_groups' ;
89+ if ( linear_indexing ) {
90+ global_index = 'inputs.global_index' ;
91+ group_index = 'inputs.group_index' ;
92+ }
7893 break ;
7994 case 'mixed' :
8095 structures = `struct InputsA {
@@ -87,12 +102,19 @@ g.test('inputs')
87102 params = `@builtin(local_invocation_id) local_id : vec3<u32>,
88103 inputsA : InputsA,
89104 inputsB : InputsB,
90- @builtin(num_workgroups) num_groups : vec3<u32>,` ;
105+ @builtin(num_workgroups) num_groups : vec3<u32>,
106+ ${ linear_indexing ? '@builtin(global_invocation_index) global_index : u32,' : '' }
107+ ${ linear_indexing ? '@builtin(workgroup_index) group_index : u32,' : '' }
108+ ` ;
91109 local_id = 'local_id' ;
92110 local_index = 'inputsA.local_index' ;
93111 global_id = 'inputsA.global_id' ;
94112 group_id = 'inputsB.group_id' ;
95113 num_groups = 'num_groups' ;
114+ if ( linear_indexing ) {
115+ global_index = 'global_index' ;
116+ group_index = 'group_index' ;
117+ }
96118 break ;
97119 }
98120
@@ -104,6 +126,8 @@ g.test('inputs')
104126 global_id: vec3u,
105127 group_id: vec3u,
106128 num_groups: vec3u,
129+ ${ linear_indexing ? 'global_index : u32,' : '' }
130+ ${ linear_indexing ? 'group_index : u32,' : '' }
107131 };
108132 @group(0) @binding(0) var<storage, read_write> outputs : array<Outputs>;
109133
@@ -117,15 +141,17 @@ g.test('inputs')
117141 fn main(
118142 ${ params }
119143 ) {
120- let group_index = ((${ group_id } .z * ${ num_groups } .y) + ${ group_id } .y) * ${ num_groups } .x + ${ group_id } .x;
121- let global_index = group_index * ${ invocationsPerGroup } u + ${ local_index } ;
144+ let o_group_index = ((${ group_id } .z * ${ num_groups } .y) + ${ group_id } .y) * ${ num_groups } .x + ${ group_id } .x;
145+ let o_global_index = o_group_index * ${ invocationsPerGroup } u + ${ local_index } ;
122146 var o: Outputs;
123147 o.local_id = ${ local_id } ;
124148 o.local_index = ${ local_index } ;
125149 o.global_id = ${ global_id } ;
126150 o.group_id = ${ group_id } ;
127151 o.num_groups = ${ num_groups } ;
128- outputs[global_index] = o;
152+ ${ linear_indexing ? `o.global_index = ${ global_index } ;` : `` }
153+ ${ linear_indexing ? `o.group_index = ${ group_index } ;` : `` }
154+ outputs[o_global_index] = o;
129155 }
130156 ` ;
131157
@@ -145,7 +171,9 @@ g.test('inputs')
145171 const kGlobalIdOffset = 4 ;
146172 const kGroupIdOffset = 8 ;
147173 const kNumGroupsOffset = 12 ;
148- const kOutputElementSize = 16 ;
174+ const kGlobalIndexOffset = 15 ;
175+ const kGroupIndexOffset = 16 ;
176+ const kOutputElementSize = linear_indexing ? 20 : 16 ;
149177
150178 // Create the output buffers.
151179 const outputBuffer = t . createBufferTracked ( {
@@ -203,6 +231,21 @@ g.test('inputs')
203231 const localIndex = ( lz * t . params . groupSize . y + ly ) * t . params . groupSize . x + lx ;
204232 const globalIndex = groupIndex * invocationsPerGroup + localIndex ;
205233 const globalOffset = globalIndex * kOutputElementSize ;
234+ const gidX = gx * t . params . groupSize . x + lx ;
235+ const gidY = gy * t . params . groupSize . y + ly ;
236+ const gidZ = gz * t . params . groupSize . z + lz ;
237+ const globalLinearIndex =
238+ gidX +
239+ gidY * t . params . groupSize . x * t . params . numGroups . x +
240+ gidZ *
241+ t . params . groupSize . x *
242+ t . params . numGroups . x *
243+ t . params . groupSize . y *
244+ t . params . numGroups . y ;
245+ const groupLinearIndex =
246+ gx +
247+ gy * t . params . numGroups . x +
248+ gz * t . params . numGroups . x * t . params . numGroups . y ;
206249
207250 const expectEqual = ( name : string , expected : number , actual : number ) => {
208251 if ( actual !== expected ) {
@@ -226,17 +269,23 @@ g.test('inputs')
226269
227270 const error =
228271 checkVec3Value ( 'local_id' , kLocalIdOffset , { x : lx , y : ly , z : lz } ) ||
229- checkVec3Value ( 'global_id' , kGlobalIdOffset , {
230- x : gx * t . params . groupSize . x + lx ,
231- y : gy * t . params . groupSize . y + ly ,
232- z : gz * t . params . groupSize . z + lz ,
233- } ) ||
272+ checkVec3Value ( 'global_id' , kGlobalIdOffset , { x : gidX , y : gidY , z : gidZ } ) ||
234273 checkVec3Value ( 'group_id' , kGroupIdOffset , { x : gx , y : gy , z : gz } ) ||
235274 checkVec3Value ( 'num_groups' , kNumGroupsOffset , t . params . numGroups ) ||
236275 expectEqual (
237276 'local_index' ,
238277 localIndex ,
239278 output [ globalOffset + kLocalIndexOffset ]
279+ ) ||
280+ expectEqual (
281+ 'global_index' ,
282+ globalLinearIndex ,
283+ output [ globalOffset + kGlobalIndexOffset ]
284+ ) ||
285+ expectEqual (
286+ 'group_index' ,
287+ groupLinearIndex ,
288+ output [ globalOffset + kGroupIndexOffset ]
240289 ) ;
241290 if ( error ) {
242291 return error ;
0 commit comments