Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/webgpu/capability_info.ts
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,7 @@ export const kKnownWGSLLanguageFeatures = [
'subgroup_id',
'subgroup_uniformity',
'swizzle_assignment',
'linear_indexing',
] as const;

export type WGSLLanguageFeature = (typeof kKnownWGSLLanguageFeatures)[number];
69 changes: 59 additions & 10 deletions src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ g.test('inputs')
.beginSubcases()
)
.fn(t => {
const linear_indexing = t.hasLanguageFeature('linear_indexing');
const invocationsPerGroup = t.params.groupSize.x * t.params.groupSize.y * t.params.groupSize.z;
const totalInvocations =
invocationsPerGroup * t.params.numGroups.x * t.params.numGroups.y * t.params.numGroups.z;
Expand All @@ -46,6 +47,8 @@ g.test('inputs')
let global_id = '';
let group_id = '';
let num_groups = '';
let global_index = '';
let group_index = '';
switch (t.params.method) {
case 'param':
params = `
Expand All @@ -54,12 +57,18 @@ g.test('inputs')
@builtin(global_invocation_id) global_id : vec3<u32>,
@builtin(workgroup_id) group_id : vec3<u32>,
@builtin(num_workgroups) num_groups : vec3<u32>,
${linear_indexing ? '@builtin(global_invocation_index) global_index : u32,' : ''}
${linear_indexing ? '@builtin(workgroup_index) group_index : u32,' : ''}
`;
local_id = 'local_id';
local_index = 'local_index';
global_id = 'global_id';
group_id = 'group_id';
num_groups = 'num_groups';
if (linear_indexing) {
global_index = 'global_index';
group_index = 'group_index';
}
break;
case 'struct':
structures = `struct Inputs {
Expand All @@ -68,13 +77,19 @@ g.test('inputs')
@builtin(global_invocation_id) global_id : vec3<u32>,
@builtin(workgroup_id) group_id : vec3<u32>,
@builtin(num_workgroups) num_groups : vec3<u32>,
${linear_indexing ? '@builtin(global_invocation_index) global_index : u32,' : ''}
${linear_indexing ? '@builtin(workgroup_index) group_index : u32,' : ''}
};`;
params = `inputs : Inputs`;
local_id = 'inputs.local_id';
local_index = 'inputs.local_index';
global_id = 'inputs.global_id';
group_id = 'inputs.group_id';
num_groups = 'inputs.num_groups';
if (linear_indexing) {
global_index = 'inputs.global_index';
group_index = 'inputs.group_index';
}
break;
case 'mixed':
structures = `struct InputsA {
Expand All @@ -87,12 +102,19 @@ g.test('inputs')
params = `@builtin(local_invocation_id) local_id : vec3<u32>,
inputsA : InputsA,
inputsB : InputsB,
@builtin(num_workgroups) num_groups : vec3<u32>,`;
@builtin(num_workgroups) num_groups : vec3<u32>,
${linear_indexing ? '@builtin(global_invocation_index) global_index : u32,' : ''}
${linear_indexing ? '@builtin(workgroup_index) group_index : u32,' : ''}
`;
local_id = 'local_id';
local_index = 'inputsA.local_index';
global_id = 'inputsA.global_id';
group_id = 'inputsB.group_id';
num_groups = 'num_groups';
if (linear_indexing) {
global_index = 'global_index';
group_index = 'group_index';
}
break;
}

Expand All @@ -104,6 +126,8 @@ g.test('inputs')
global_id: vec3u,
group_id: vec3u,
num_groups: vec3u,
${linear_indexing ? 'global_index : u32,' : ''}
${linear_indexing ? 'group_index : u32,' : ''}
};
@group(0) @binding(0) var<storage, read_write> outputs : array<Outputs>;

Expand All @@ -117,15 +141,17 @@ g.test('inputs')
fn main(
${params}
) {
let group_index = ((${group_id}.z * ${num_groups}.y) + ${group_id}.y) * ${num_groups}.x + ${group_id}.x;
let global_index = group_index * ${invocationsPerGroup}u + ${local_index};
let o_group_index = ((${group_id}.z * ${num_groups}.y) + ${group_id}.y) * ${num_groups}.x + ${group_id}.x;
let o_global_index = o_group_index * ${invocationsPerGroup}u + ${local_index};
var o: Outputs;
o.local_id = ${local_id};
o.local_index = ${local_index};
o.global_id = ${global_id};
o.group_id = ${group_id};
o.num_groups = ${num_groups};
outputs[global_index] = o;
${linear_indexing ? `o.global_index = ${global_index};` : ``}
${linear_indexing ? `o.group_index = ${group_index};` : ``}
outputs[o_global_index] = o;
}
`;

Expand All @@ -145,7 +171,9 @@ g.test('inputs')
const kGlobalIdOffset = 4;
const kGroupIdOffset = 8;
const kNumGroupsOffset = 12;
const kOutputElementSize = 16;
const kGlobalIndexOffset = 15;
const kGroupIndexOffset = 16;
const kOutputElementSize = linear_indexing ? 20 : 16;

// Create the output buffers.
const outputBuffer = t.createBufferTracked({
Expand Down Expand Up @@ -203,6 +231,21 @@ g.test('inputs')
const localIndex = (lz * t.params.groupSize.y + ly) * t.params.groupSize.x + lx;
const globalIndex = groupIndex * invocationsPerGroup + localIndex;
const globalOffset = globalIndex * kOutputElementSize;
const gidX = gx * t.params.groupSize.x + lx;
const gidY = gy * t.params.groupSize.y + ly;
const gidZ = gz * t.params.groupSize.z + lz;
const globalLinearIndex =
gidX +
gidY * t.params.groupSize.x * t.params.numGroups.x +
gidZ *
t.params.groupSize.x *
t.params.numGroups.x *
t.params.groupSize.y *
t.params.numGroups.y;
const groupLinearIndex =
gx +
gy * t.params.numGroups.x +
gz * t.params.numGroups.x * t.params.numGroups.y;

const expectEqual = (name: string, expected: number, actual: number) => {
if (actual !== expected) {
Expand All @@ -226,17 +269,23 @@ g.test('inputs')

const error =
checkVec3Value('local_id', kLocalIdOffset, { x: lx, y: ly, z: lz }) ||
checkVec3Value('global_id', kGlobalIdOffset, {
x: gx * t.params.groupSize.x + lx,
y: gy * t.params.groupSize.y + ly,
z: gz * t.params.groupSize.z + lz,
}) ||
checkVec3Value('global_id', kGlobalIdOffset, { x: gidX, y: gidY, z: gidZ }) ||
checkVec3Value('group_id', kGroupIdOffset, { x: gx, y: gy, z: gz }) ||
checkVec3Value('num_groups', kNumGroupsOffset, t.params.numGroups) ||
expectEqual(
'local_index',
localIndex,
output[globalOffset + kLocalIndexOffset]
) ||
expectEqual(
'global_index',
globalLinearIndex,
output[globalOffset + kGlobalIndexOffset]
) ||
expectEqual(
'group_index',
groupLinearIndex,
output[globalOffset + kGroupIndexOffset]
);
if (error) {
return error;
Expand Down
14 changes: 14 additions & 0 deletions src/webgpu/shader/validation/shader_io/builtins.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,20 @@ export const kBuiltins: readonly Builtin[] = [
enable: 'subgroups',
requires: 'subgroup_id',
},
{
name: 'workgroup_index',
stage: 'compute',
io: 'in',
type: 'u32',
requires: 'linear_indexing',
},
{
name: 'global_invocation_index',
stage: 'compute',
io: 'in',
type: 'u32',
requires: 'linear_indexing',
},
] as const;

// List of types to test against.
Expand Down