diff --git a/src/webgpu/capability_info.ts b/src/webgpu/capability_info.ts index 297fa55ccbe2..e2cf0e71400d 100644 --- a/src/webgpu/capability_info.ts +++ b/src/webgpu/capability_info.ts @@ -981,6 +981,7 @@ export const kKnownWGSLLanguageFeatures = [ 'subgroup_id', 'subgroup_uniformity', 'swizzle_assignment', + 'linear_indexing', ] as const; export type WGSLLanguageFeature = (typeof kKnownWGSLLanguageFeatures)[number]; diff --git a/src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts b/src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts index 2d57a0f4fb1f..08183a3626c2 100644 --- a/src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts +++ b/src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts @@ -34,6 +34,7 @@ g.test('inputs') .beginSubcases() ) .fn(t => { + const linear_indexing = t.hasLanguageFeature('linear_indexing'); const invocationsPerGroup = t.params.groupSize.x * t.params.groupSize.y * t.params.groupSize.z; const totalInvocations = invocationsPerGroup * t.params.numGroups.x * t.params.numGroups.y * t.params.numGroups.z; @@ -46,6 +47,8 @@ g.test('inputs') let global_id = ''; let group_id = ''; let num_groups = ''; + let global_index = ''; + let group_index = ''; switch (t.params.method) { case 'param': params = ` @@ -54,12 +57,18 @@ g.test('inputs') @builtin(global_invocation_id) global_id : vec3, @builtin(workgroup_id) group_id : vec3, @builtin(num_workgroups) num_groups : vec3, + ${linear_indexing ? '@builtin(global_invocation_index) global_index : u32,' : ''} + ${linear_indexing ? '@builtin(workgroup_index) group_index : u32,' : ''} `; local_id = 'local_id'; local_index = 'local_index'; global_id = 'global_id'; group_id = 'group_id'; num_groups = 'num_groups'; + if (linear_indexing) { + global_index = 'global_index'; + group_index = 'group_index'; + } break; case 'struct': structures = `struct Inputs { @@ -68,6 +77,8 @@ g.test('inputs') @builtin(global_invocation_id) global_id : vec3, @builtin(workgroup_id) group_id : vec3, @builtin(num_workgroups) num_groups : vec3, + ${linear_indexing ? '@builtin(global_invocation_index) global_index : u32,' : ''} + ${linear_indexing ? '@builtin(workgroup_index) group_index : u32,' : ''} };`; params = `inputs : Inputs`; local_id = 'inputs.local_id'; @@ -75,6 +86,10 @@ g.test('inputs') global_id = 'inputs.global_id'; group_id = 'inputs.group_id'; num_groups = 'inputs.num_groups'; + if (linear_indexing) { + global_index = 'inputs.global_index'; + group_index = 'inputs.group_index'; + } break; case 'mixed': structures = `struct InputsA { @@ -87,12 +102,19 @@ g.test('inputs') params = `@builtin(local_invocation_id) local_id : vec3, inputsA : InputsA, inputsB : InputsB, - @builtin(num_workgroups) num_groups : vec3,`; + @builtin(num_workgroups) num_groups : vec3, + ${linear_indexing ? '@builtin(global_invocation_index) global_index : u32,' : ''} + ${linear_indexing ? '@builtin(workgroup_index) group_index : u32,' : ''} + `; local_id = 'local_id'; local_index = 'inputsA.local_index'; global_id = 'inputsA.global_id'; group_id = 'inputsB.group_id'; num_groups = 'num_groups'; + if (linear_indexing) { + global_index = 'global_index'; + group_index = 'group_index'; + } break; } @@ -104,6 +126,8 @@ g.test('inputs') global_id: vec3u, group_id: vec3u, num_groups: vec3u, + ${linear_indexing ? 'global_index : u32,' : ''} + ${linear_indexing ? 'group_index : u32,' : ''} }; @group(0) @binding(0) var outputs : array; @@ -117,15 +141,17 @@ g.test('inputs') fn main( ${params} ) { - let group_index = ((${group_id}.z * ${num_groups}.y) + ${group_id}.y) * ${num_groups}.x + ${group_id}.x; - let global_index = group_index * ${invocationsPerGroup}u + ${local_index}; + let o_group_index = ((${group_id}.z * ${num_groups}.y) + ${group_id}.y) * ${num_groups}.x + ${group_id}.x; + let o_global_index = o_group_index * ${invocationsPerGroup}u + ${local_index}; var o: Outputs; o.local_id = ${local_id}; o.local_index = ${local_index}; o.global_id = ${global_id}; o.group_id = ${group_id}; o.num_groups = ${num_groups}; - outputs[global_index] = o; + ${linear_indexing ? `o.global_index = ${global_index};` : ``} + ${linear_indexing ? `o.group_index = ${group_index};` : ``} + outputs[o_global_index] = o; } `; @@ -145,7 +171,9 @@ g.test('inputs') const kGlobalIdOffset = 4; const kGroupIdOffset = 8; const kNumGroupsOffset = 12; - const kOutputElementSize = 16; + const kGlobalIndexOffset = 15; + const kGroupIndexOffset = 16; + const kOutputElementSize = linear_indexing ? 20 : 16; // Create the output buffers. const outputBuffer = t.createBufferTracked({ @@ -203,6 +231,21 @@ g.test('inputs') const localIndex = (lz * t.params.groupSize.y + ly) * t.params.groupSize.x + lx; const globalIndex = groupIndex * invocationsPerGroup + localIndex; const globalOffset = globalIndex * kOutputElementSize; + const gidX = gx * t.params.groupSize.x + lx; + const gidY = gy * t.params.groupSize.y + ly; + const gidZ = gz * t.params.groupSize.z + lz; + const globalLinearIndex = + gidX + + gidY * t.params.groupSize.x * t.params.numGroups.x + + gidZ * + t.params.groupSize.x * + t.params.numGroups.x * + t.params.groupSize.y * + t.params.numGroups.y; + const groupLinearIndex = + gx + + gy * t.params.numGroups.x + + gz * t.params.numGroups.x * t.params.numGroups.y; const expectEqual = (name: string, expected: number, actual: number) => { if (actual !== expected) { @@ -226,17 +269,23 @@ g.test('inputs') const error = checkVec3Value('local_id', kLocalIdOffset, { x: lx, y: ly, z: lz }) || - checkVec3Value('global_id', kGlobalIdOffset, { - x: gx * t.params.groupSize.x + lx, - y: gy * t.params.groupSize.y + ly, - z: gz * t.params.groupSize.z + lz, - }) || + checkVec3Value('global_id', kGlobalIdOffset, { x: gidX, y: gidY, z: gidZ }) || checkVec3Value('group_id', kGroupIdOffset, { x: gx, y: gy, z: gz }) || checkVec3Value('num_groups', kNumGroupsOffset, t.params.numGroups) || expectEqual( 'local_index', localIndex, output[globalOffset + kLocalIndexOffset] + ) || + expectEqual( + 'global_index', + globalLinearIndex, + output[globalOffset + kGlobalIndexOffset] + ) || + expectEqual( + 'group_index', + groupLinearIndex, + output[globalOffset + kGroupIndexOffset] ); if (error) { return error; diff --git a/src/webgpu/shader/validation/shader_io/builtins.spec.ts b/src/webgpu/shader/validation/shader_io/builtins.spec.ts index e7d5b1070823..1c4e11aafb0e 100644 --- a/src/webgpu/shader/validation/shader_io/builtins.spec.ts +++ b/src/webgpu/shader/validation/shader_io/builtins.spec.ts @@ -119,6 +119,20 @@ export const kBuiltins: readonly Builtin[] = [ enable: 'subgroups', requires: 'subgroup_id', }, + { + name: 'workgroup_index', + stage: 'compute', + io: 'in', + type: 'u32', + requires: 'linear_indexing', + }, + { + name: 'global_invocation_index', + stage: 'compute', + io: 'in', + type: 'u32', + requires: 'linear_indexing', + }, ] as const; // List of types to test against.