Skip to content
Open
93 changes: 92 additions & 1 deletion crates/rustc_codegen_spirv/src/linker/duplicates.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::custom_insts::{self, CustomOp};
use rspirv::binary::Assemble;
use rspirv::dr::{Instruction, Module, Operand};
use rspirv::spirv::{Op, Word};
use rspirv::spirv::{BuiltIn, Decoration, Op, StorageClass, Word};
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_middle::bug;
use smallvec::SmallVec;
Expand Down Expand Up @@ -104,6 +104,7 @@ fn gather_annotations(annotations: &[Instruction]) -> FxHashMap<Word, Vec<u32>>
.collect()
}

/// Returns a map from an ID to its debug name (given by `OpName`).
fn gather_names(debug_names: &[Instruction]) -> FxHashMap<Word, String> {
debug_names
.iter()
Expand Down Expand Up @@ -541,3 +542,93 @@ pub fn remove_duplicate_debuginfo(module: &mut Module) {
}
}
}

pub fn remove_duplicate_builtin_input_variables(module: &mut Module) {
// Find the variables decorated as input builtins, and any duplicates of them..

// Build a map: from a variable ID to the builtin it's decorated with.
let var_id_to_builtin: FxHashMap<Word, BuiltIn>;
{
let mut var_id_to_builtin_mut = FxHashMap::default();

for inst in module.annotations.iter() {
if inst.class.opcode == Op::Decorate
&& let [
Operand::IdRef(var_id),
Operand::Decoration(Decoration::BuiltIn),
Operand::BuiltIn(builtin),
] = inst.operands[..]
{
let prev = var_id_to_builtin_mut.insert(var_id, builtin);
assert!(
prev.is_none(),
"An OpVariable `{inst:?}` shouldn't have more than one Builtin decoration, \
but it has at least two: {builtin:?}, {prev:?}"
);
}
}
var_id_to_builtin = var_id_to_builtin_mut;
};

// Build a map from deleted duplicate input variable ID to the de-duplicated ID.
let duplicate_vars: FxHashMap<Word, Word>;
{
let mut duplicate_in_vars_mut = FxHashMap::<Word, Word>::default();

// Map from builtin to de-duped input variable ID.
let mut builtin_to_input_var_id = FxHashMap::<BuiltIn, Word>::default();

for inst in module.types_global_values.iter() {
if inst.class.opcode == Op::Variable
&& let [Operand::StorageClass(StorageClass::Input), ..] = inst.operands[..]
&& let Some(var_id) = inst.result_id
&& let Some(builtin) = var_id_to_builtin.get(&var_id)
{
match builtin_to_input_var_id.entry(*builtin) {
// first input variable we've seen for this builtin,
// record it in the builtins map.
hash_map::Entry::Vacant(vacant) => {
vacant.insert(var_id);
}

// this builtin already has an input variable,
// record it in the duplicates map.
hash_map::Entry::Occupied(occupied) => {
duplicate_in_vars_mut.insert(var_id, *occupied.get());
}
};
}
}
duplicate_vars = duplicate_in_vars_mut;
};

// Rewrite entry points
for inst in &mut module.entry_points {
rewrite_inst_with_rules(inst, &duplicate_vars);
}

// Rewrite function blocks to use de-duplicated variables.
for inst in &mut module
.functions
.iter_mut()
.flat_map(|f| &mut f.blocks)
.flat_map(|b| &mut b.instructions)
{
rewrite_inst_with_rules(inst, &duplicate_vars);
}

// Remove duplicate BuiltIn decorations
module.annotations.retain(|inst| {
!(inst.class.opcode == Op::Decorate
&& matches!(inst.operands[..], [
Operand::IdRef(var_id),
Operand::Decoration(Decoration::BuiltIn),
Operand::BuiltIn(_),
] if duplicate_vars.contains_key(&var_id)))
});

// Remove the duplicate variable definitions.
module
.types_global_values
.retain(|inst| !matches!(inst.result_id, Some(id) if duplicate_vars.contains_key(&id)));
}
1 change: 1 addition & 0 deletions crates/rustc_codegen_spirv/src/linker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ pub fn link(
duplicates::remove_duplicate_extensions(&mut output);
duplicates::remove_duplicate_capabilities(&mut output);
duplicates::remove_duplicate_ext_inst_imports(&mut output);
duplicates::remove_duplicate_builtin_input_variables(&mut output);
duplicates::remove_duplicate_types(&mut output);
// jb-todo: strip identical OpDecoration / OpDecorationGroups
}
Expand Down
143 changes: 143 additions & 0 deletions crates/spirv-std/src/builtin.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
//! Query SPIR-V read-only global built-in values
//!
//! Reference links:
//! * [WGSL specification describing builtins](https://www.w3.org/TR/WGSL/#builtin-inputs-outputs)
//! * [SPIR-V specification for builtins](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_builtin)
//! * [GLSL reference](https://registry.khronos.org/OpenGL-Refpages/gl4/)
//! * [GLSL reference source code](https://github.com/KhronosGroup/OpenGL-Refpages/tree/main/gl4)
//! * [GLSL extensions](https://github.com/KhronosGroup/GLSL/tree/main/extensions)

/// Loads the value of a SPIR-V `BuiltIn` input by name, with an optional result type.
///
/// `load_builtin!(Name)` lets the result type be inferred from the surrounding
/// context, while `load_builtin!(Name: Type)` states it explicitly. Either way
/// the type must implement [`Default`], because the result is
/// default-initialised before it is overwritten with the loaded value.
///
/// Each expansion emits its own `Input` `OpVariable` decorated with
/// `BuiltIn <Name>` and loads from it.
///
/// The macro only exists when compiling for `target_arch = "spirv"`.
#[cfg(target_arch = "spirv")]
#[macro_export]
macro_rules! load_builtin {
    ($name:ident $(: $ty:ty)?) => {
        // SAFETY: the only memory the assembly writes is `result`, through the
        // pointer passed as `result_ref`, and `result` is live for the whole block.
        // NOTE(review): nothing here checks that the result type matches the type
        // SPIR-V requires for `$name`; callers must pick a matching type.
        unsafe {
            // Absolute paths keep the expansion independent of whatever `core`
            // or `Default` names happen to be in scope at the call site.
            let mut result $(: $ty)? = ::core::default::Default::default();
            ::core::arch::asm! {
                "%builtin = OpVariable typeof{result_ref} Input",
                concat!("OpDecorate %builtin BuiltIn ", stringify!($name)),
                "%result = OpLoad typeof*{result_ref} %builtin",
                "OpStore {result_ref} %result",
                result_ref = in(reg) &mut result,
            }
            result
        }
    };
}

/// Compute shader built-ins
pub mod compute {
    use glam::UVec3;

    // Local builtins (for this invocation's position in the workgroup).

    /// The current invocation's local invocation ID,
    /// i.e. its position in the workgroup grid.
    ///
    /// * GLSL: `gl_LocalInvocationID`
    /// * WGSL: `local_invocation_id`
    #[doc(alias = "gl_LocalInvocationID")]
    #[inline]
    #[gpu_only]
    pub fn local_invocation_id() -> UVec3 {
        load_builtin!(LocalInvocationId)
    }

    /// The current invocation's local invocation index,
    /// a linearized index of the invocation's position within the workgroup grid.
    ///
    /// * GLSL: `gl_LocalInvocationIndex`
    /// * WGSL: `local_invocation_index`
    #[doc(alias = "gl_LocalInvocationIndex")]
    #[inline]
    #[gpu_only]
    pub fn local_invocation_index() -> u32 {
        load_builtin!(LocalInvocationIndex)
    }

    // Global builtins, for this invocation's position in the compute grid.

    /// The current invocation's global invocation ID,
    /// i.e. its position in the compute shader grid.
    ///
    /// * GLSL: `gl_GlobalInvocationID`
    /// * WGSL: `global_invocation_id`
    #[doc(alias = "gl_GlobalInvocationID")]
    #[inline]
    #[gpu_only]
    pub fn global_invocation_id() -> UVec3 {
        load_builtin!(GlobalInvocationId)
    }

    // Subgroup builtins

    /// The number of subgroups in the current invocation's workgroup.
    ///
    /// * GLSL: [`gl_NumSubgroups`](https://github.com/KhronosGroup/GLSL/blob/main/extensions/khr/GL_KHR_shader_subgroup.txt)
    /// * WGSL: `num_subgroups`
    #[doc(alias = "gl_NumSubgroups")]
    #[inline]
    #[gpu_only]
    pub fn num_subgroups() -> u32 {
        load_builtin!(NumSubgroups)
    }

    /// The subgroup ID of the current invocation's subgroup within the workgroup.
    ///
    /// * GLSL: [`gl_SubgroupID`](https://github.com/KhronosGroup/GLSL/blob/main/extensions/khr/GL_KHR_shader_subgroup.txt)
    /// * WGSL: `subgroup_id`
    #[doc(alias = "gl_SubgroupID")]
    #[inline]
    #[gpu_only]
    pub fn subgroup_id() -> u32 {
        load_builtin!(SubgroupId)
    }

    /// This invocation's ID within its subgroup.
    ///
    /// * GLSL: [`gl_SubgroupInvocationID`](https://github.com/KhronosGroup/GLSL/blob/main/extensions/khr/GL_KHR_shader_subgroup.txt)
    /// * WGSL: `subgroup_invocation_id`
    #[doc(alias = "gl_SubgroupInvocationID")]
    #[inline]
    #[gpu_only]
    pub fn subgroup_invocation_id() -> u32 {
        load_builtin!(SubgroupLocalInvocationId)
    }

    /// The subgroup size of the current invocation's subgroup.
    ///
    /// * GLSL: [`gl_SubgroupSize`](https://github.com/KhronosGroup/GLSL/blob/main/extensions/khr/GL_KHR_shader_subgroup.txt)
    /// * WGSL: `subgroup_size`
    #[doc(alias = "gl_SubgroupSize")]
    #[inline]
    #[gpu_only]
    pub fn subgroup_size() -> u32 {
        load_builtin!(SubgroupSize)
    }

    // Workgroup builtins

    /// The number of workgroups that have been dispatched in the compute shader grid.
    ///
    /// * GLSL: `gl_NumWorkGroups`
    /// * WGSL: `num_workgroups`
    #[doc(alias = "gl_NumWorkGroups")]
    #[inline]
    #[gpu_only]
    pub fn num_workgroups() -> UVec3 {
        load_builtin!(NumWorkgroups)
    }

    /// The current invocation's workgroup ID,
    /// i.e. the position of the workgroup in the overall compute shader grid.
    ///
    /// * GLSL: `gl_WorkGroupID`
    /// * WGSL: `workgroup_id`
    #[doc(alias = "gl_WorkGroupID")]
    #[inline]
    #[gpu_only]
    pub fn workgroup_id() -> UVec3 {
        load_builtin!(WorkgroupId)
    }
}
1 change: 1 addition & 0 deletions crates/spirv-std/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ pub use macros::spirv;
pub use macros::{debug_printf, debug_printfln};

pub mod arch;
pub mod builtin;
pub mod byte_addressable_buffer;
pub mod debug_printf;
pub mod float;
Expand Down
39 changes: 39 additions & 0 deletions tests/compiletests/ui/builtin/compute.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// build-pass
// compile-flags: -C llvm-args=--disassemble
// compile-flags: -C target-feature=+GroupNonUniform
// normalize-stderr-test "OpLine .*\n" -> ""
// normalize-stderr-test "OpSource .*\n" -> ""
// normalize-stderr-test "%\d+ = OpString .*\n" -> ""
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
// normalize-stderr-test "; .*\n" -> ""
// ignore-spv1.0
// ignore-spv1.1
// ignore-spv1.2
// ignore-spv1.3
// ignore-vulkan1.0
// ignore-vulkan1.1

use spirv_std::{builtin::compute, glam::*, spirv};

// Compute entry point that calls every accessor in `spirv_std::builtin::compute`
// once and writes the local invocation index to `out`. All other results are
// bound to `_`-prefixed names and otherwise unused.
#[spirv(compute(threads(1)))]
pub fn compute(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut u32) {
// Local IDs
let _local_invocation_id: UVec3 = compute::local_invocation_id();
let local_invocation_index: u32 = compute::local_invocation_index();

// Global IDs
let _global_invocation_id: UVec3 = compute::global_invocation_id();

// Subgroup IDs
let _num_subgroups: u32 = compute::num_subgroups();
let _subgroup_id: u32 = compute::subgroup_id();
let _subgroup_invocation_id: u32 = compute::subgroup_invocation_id();
let _subgroup_size: u32 = compute::subgroup_size();

// Workgroup IDs
let _num_workgroups: UVec3 = compute::num_workgroups();
let _workgroup_id: UVec3 = compute::workgroup_id();

// The only value that reaches the output buffer.
*out = local_invocation_index;
}
54 changes: 54 additions & 0 deletions tests/compiletests/ui/builtin/compute.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
OpCapability Shader
OpCapability GroupNonUniform
OpMemoryModel Logical Simple
OpEntryPoint GLCompute %1 "compute" %2 %3 %4 %5 %6 %7 %8 %9 %10 %11
OpExecutionMode %1 LocalSize 1 1 1
OpDecorate %14 Block
OpMemberDecorate %14 0 Offset 0
OpDecorate %2 Binding 0
OpDecorate %2 DescriptorSet 0
OpDecorate %3 BuiltIn LocalInvocationId
OpDecorate %4 BuiltIn LocalInvocationIndex
OpDecorate %5 BuiltIn GlobalInvocationId
OpDecorate %6 BuiltIn NumSubgroups
OpDecorate %7 BuiltIn SubgroupId
OpDecorate %8 BuiltIn SubgroupLocalInvocationId
OpDecorate %9 BuiltIn SubgroupSize
OpDecorate %10 BuiltIn NumWorkgroups
OpDecorate %11 BuiltIn WorkgroupId
%15 = OpTypeInt 32 0
%14 = OpTypeStruct %15
%16 = OpTypePointer StorageBuffer %14
%17 = OpTypeVoid
%18 = OpTypeFunction %17
%19 = OpTypePointer StorageBuffer %15
%2 = OpVariable %16 StorageBuffer
%20 = OpConstant %15 0
%21 = OpTypeVector %15 3
%22 = OpTypePointer Input %21
%3 = OpVariable %22 Input
%23 = OpTypePointer Input %15
%4 = OpVariable %23 Input
%5 = OpVariable %22 Input
%6 = OpVariable %23 Input
%7 = OpVariable %23 Input
%8 = OpVariable %23 Input
%9 = OpVariable %23 Input
%10 = OpVariable %22 Input
%11 = OpVariable %22 Input
%1 = OpFunction %17 None %18
%24 = OpLabel
%25 = OpInBoundsAccessChain %19 %2 %20
%26 = OpLoad %21 %3
%27 = OpLoad %15 %4
%28 = OpLoad %21 %5
%29 = OpLoad %15 %6
%30 = OpLoad %15 %7
%31 = OpLoad %15 %8
%32 = OpLoad %15 %9
%33 = OpLoad %21 %10
%34 = OpLoad %21 %11
OpStore %25 %27
OpNoLine
OpReturn
OpFunctionEnd
31 changes: 31 additions & 0 deletions tests/compiletests/ui/builtin/compute_attr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// build-pass
// compile-flags: -C llvm-args=--disassemble
// normalize-stderr-test "OpLine .*\n" -> ""
// normalize-stderr-test "OpSource .*\n" -> ""
// normalize-stderr-test "%\d+ = OpString .*\n" -> ""
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
// normalize-stderr-test "; .*\n" -> ""
// ignore-spv1.0
// ignore-spv1.1
// ignore-spv1.2
// ignore-spv1.3
// ignore-vulkan1.0
// ignore-vulkan1.1

use spirv_std::glam::*;
use spirv_std::spirv;

// Compute entry point that receives `local_invocation_id` through a
// `#[spirv(...)]` parameter attribute and writes its `x` component to `out`.
#[spirv(compute(threads(1)))]
pub fn compute(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut u32,
// NOTE(review): the commented-out parameters around `local_invocation_id` are
// other builtin inputs that are currently disabled; presumably kept as a list
// of candidates to enable later. Confirm whether they should be restored or removed.
// #[spirv(global_invocation_id)] global_invocation_id: UVec3,
#[spirv(local_invocation_id)] local_invocation_id: UVec3,
// #[spirv(subgroup_local_invocation_id)] subgroup_local_invocation_id: u32,
// #[spirv(num_subgroups)] num_subgroups: u32,
// #[spirv(num_workgroups)] num_workgroups: UVec3,
// #[spirv(subgroup_id)] subgroup_id: u32,
// #[spirv(workgroup_id)] workgroup_id: UVec3,
) {
*out = local_invocation_id.x;
}
Loading
Loading