Skip to content

Commit 42a3a3f

Browse files
committed
Auto merge of #146181 - Flakebi:dynamic-shared-memory, r=<try>
Add intrinsic for launch-sized workgroup memory on GPUs try-job: x86_64-gnu-nopt try-job: x86_64-gnu-debug
2 parents f889772 + 3541dd4 commit 42a3a3f

11 files changed

Lines changed: 168 additions & 7 deletions

File tree

compiler/rustc_abi/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1702,6 +1702,9 @@ pub struct AddressSpace(pub u32);
17021702
impl AddressSpace {
17031703
/// LLVM's `0` address space.
17041704
pub const ZERO: Self = AddressSpace(0);
1705+
/// The address space for workgroup memory on nvptx and amdgpu.
1706+
/// See e.g. the `gpu_launch_sized_workgroup_mem` intrinsic for details.
1707+
pub const GPU_WORKGROUP: Self = AddressSpace(3);
17051708
}
17061709

17071710
/// The way we represent values to the backend

compiler/rustc_codegen_llvm/src/declare.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
use std::borrow::Borrow;
1515

1616
use itertools::Itertools;
17+
use rustc_abi::AddressSpace;
1718
use rustc_codegen_ssa::traits::TypeMembershipCodegenMethods;
1819
use rustc_data_structures::fx::FxIndexSet;
1920
use rustc_middle::ty::{Instance, Ty};
@@ -97,6 +98,28 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
9798
)
9899
}
99100
}
101+
102+
/// Declare a global value in a specific address space.
103+
///
104+
/// If there’s a value with the same name already declared, the function will
105+
/// return its Value instead.
106+
pub(crate) fn declare_global_in_addrspace(
107+
&self,
108+
name: &str,
109+
ty: &'ll Type,
110+
addr_space: AddressSpace,
111+
) -> &'ll Value {
112+
debug!("declare_global(name={name:?}, addrspace={addr_space:?})");
113+
unsafe {
114+
llvm::LLVMRustGetOrInsertGlobalInAddrspace(
115+
(**self).borrow().llmod,
116+
name.as_c_char_ptr(),
117+
name.len(),
118+
ty,
119+
addr_space.0,
120+
)
121+
}
122+
}
100123
}
101124

102125
impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ use std::ffi::c_uint;
33
use std::ptr;
44

55
use rustc_abi::{
6-
Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size, WrappingRange,
6+
AddressSpace, Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size,
7+
WrappingRange,
78
};
89
use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh};
910
use rustc_codegen_ssa::codegen_attrs::autodiff_attrs;
@@ -24,7 +25,7 @@ use rustc_session::config::CrateType;
2425
use rustc_span::{Span, Symbol, sym};
2526
use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate};
2627
use rustc_target::callconv::PassMode;
27-
use rustc_target::spec::Os;
28+
use rustc_target::spec::{Arch, Os};
2829
use tracing::debug;
2930

3031
use crate::abi::FnAbiLlvmExt;
@@ -554,6 +555,44 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
554555
return Ok(());
555556
}
556557

558+
sym::gpu_launch_sized_workgroup_mem => {
559+
// The name of the global variable is not relevant; the important properties are:
560+
// 1. The global is in the address space for workgroup memory
561+
// 2. It is an extern global
562+
// All instances of extern addrspace(gpu_workgroup) globals are merged in the LLVM backend.
563+
// Generate an unnamed global per intrinsic call, so that different kernels can have
564+
// different minimum alignments.
565+
// See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared
566+
// FIXME Work around an nvptx backend issue where extern globals must have a name
567+
let name = if tcx.sess.target.arch == Arch::Nvptx64 {
568+
"gpu_launch_sized_workgroup_mem"
569+
} else {
570+
""
571+
};
572+
let global = self.declare_global_in_addrspace(
573+
name,
574+
self.type_array(self.type_i8(), 0),
575+
AddressSpace::GPU_WORKGROUP,
576+
);
577+
let ty::RawPtr(inner_ty, _) = result.layout.ty.kind() else { unreachable!() };
578+
// The alignment of the global is used to specify the *minimum* alignment that
579+
// must be obeyed by the GPU runtime.
580+
// When multiple such global variables are used by a kernel, the maximum alignment is taken.
581+
// See https://github.com/llvm/llvm-project/blob/a271d07488a85ce677674bbe8101b10efff58c95/llvm/lib/Target/AMDGPU/AMDGPULowerModuleLDSPass.cpp#L821
582+
let alignment = self.align_of(*inner_ty).bytes() as u32;
583+
unsafe {
584+
// FIXME Work around the above issue by taking the maximum alignment if the global already existed
585+
if tcx.sess.target.arch == Arch::Nvptx64 {
586+
if alignment > llvm::LLVMGetAlignment(global) {
587+
llvm::LLVMSetAlignment(global, alignment);
588+
}
589+
} else {
590+
llvm::LLVMSetAlignment(global, alignment);
591+
}
592+
}
593+
self.cx().const_pointercast(global, self.type_ptr())
594+
}
595+
557596
sym::amdgpu_dispatch_ptr => {
558597
let val = self.call_intrinsic("llvm.amdgcn.dispatch.ptr", &[], &[]);
559598
// Relying on `LLVMBuildPointerCast` to produce an addrspacecast

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1989,6 +1989,13 @@ unsafe extern "C" {
19891989
NameLen: size_t,
19901990
T: &'a Type,
19911991
) -> &'a Value;
1992+
pub(crate) fn LLVMRustGetOrInsertGlobalInAddrspace<'a>(
1993+
M: &'a Module,
1994+
Name: *const c_char,
1995+
NameLen: size_t,
1996+
T: &'a Type,
1997+
AddressSpace: c_uint,
1998+
) -> &'a Value;
19921999
pub(crate) fn LLVMRustGetNamedValue(
19932000
M: &Module,
19942001
Name: *const c_char,

compiler/rustc_codegen_ssa/src/mir/intrinsic.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
111111
sym::abort
112112
| sym::unreachable
113113
| sym::cold_path
114+
| sym::gpu_launch_sized_workgroup_mem
114115
| sym::breakpoint
115116
| sym::amdgpu_dispatch_ptr
116117
| sym::assert_zero_valid

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi
133133
| sym::forget
134134
| sym::frem_algebraic
135135
| sym::fsub_algebraic
136+
| sym::gpu_launch_sized_workgroup_mem
136137
| sym::is_val_statically_known
137138
| sym::log2f16
138139
| sym::log2f32
@@ -298,6 +299,7 @@ pub(crate) fn check_intrinsic_type(
298299
sym::offset_of => (1, 0, vec![tcx.types.u32, tcx.types.u32], tcx.types.usize),
299300
sym::rustc_peek => (1, 0, vec![param(0)], param(0)),
300301
sym::caller_location => (0, 0, vec![], tcx.caller_location_ty()),
302+
sym::gpu_launch_sized_workgroup_mem => (1, 0, vec![], Ty::new_mut_ptr(tcx, param(0))),
301303
sym::assert_inhabited | sym::assert_zero_valid | sym::assert_mem_uninitialized_valid => {
302304
(1, 0, vec![], tcx.types.unit)
303305
}

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -298,10 +298,10 @@ extern "C" LLVMValueRef LLVMRustGetOrInsertFunction(LLVMModuleRef M,
298298
.getCallee());
299299
}
300300

301-
extern "C" LLVMValueRef LLVMRustGetOrInsertGlobal(LLVMModuleRef M,
302-
const char *Name,
303-
size_t NameLen,
304-
LLVMTypeRef Ty) {
301+
extern "C" LLVMValueRef
302+
LLVMRustGetOrInsertGlobalInAddrspace(LLVMModuleRef M, const char *Name,
303+
size_t NameLen, LLVMTypeRef Ty,
304+
unsigned AddressSpace) {
305305
Module *Mod = unwrap(M);
306306
auto NameRef = StringRef(Name, NameLen);
307307

@@ -312,10 +312,21 @@ extern "C" LLVMValueRef LLVMRustGetOrInsertGlobal(LLVMModuleRef M,
312312
GlobalVariable *GV = Mod->getGlobalVariable(NameRef, true);
313313
if (!GV)
314314
GV = new GlobalVariable(*Mod, unwrap(Ty), false,
315-
GlobalValue::ExternalLinkage, nullptr, NameRef);
315+
GlobalValue::ExternalLinkage, nullptr, NameRef,
316+
nullptr, GlobalValue::NotThreadLocal, AddressSpace);
316317
return wrap(GV);
317318
}
318319

320+
extern "C" LLVMValueRef LLVMRustGetOrInsertGlobal(LLVMModuleRef M,
321+
const char *Name,
322+
size_t NameLen,
323+
LLVMTypeRef Ty) {
324+
Module *Mod = unwrap(M);
325+
unsigned AddressSpace = Mod->getDataLayout().getDefaultGlobalsAddressSpace();
326+
return LLVMRustGetOrInsertGlobalInAddrspace(M, Name, NameLen, Ty,
327+
AddressSpace);
328+
}
329+
319330
// Must match the layout of `rustc_codegen_llvm::llvm::ffi::AttributeKind`.
320331
enum class LLVMRustAttributeKind {
321332
AlwaysInline = 0,

compiler/rustc_span/src/symbol.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,6 +1179,7 @@ symbols! {
11791179
global_asm,
11801180
global_registration,
11811181
globs,
1182+
gpu_launch_sized_workgroup_mem,
11821183
gt,
11831184
guard_patterns,
11841185
half_open_range_patterns,

library/core/src/intrinsics/mod.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3457,6 +3457,45 @@ pub(crate) const fn miri_promise_symbolic_alignment(ptr: *const (), align: usize
34573457
)
34583458
}
34593459

3460+
/// Returns the pointer to workgroup memory allocated at launch-time on GPUs.
3461+
///
3462+
/// Workgroup memory is a memory region that is shared between all threads in
3463+
/// the same workgroup. It is faster to access than other memory but pointers do not
3464+
/// work outside the workgroup where they were obtained.
3465+
/// Workgroup memory can be allocated statically or after compilation, when
3466+
/// launching a gpu-kernel. `gpu_launch_sized_workgroup_mem` returns the pointer to
3467+
/// the memory that is allocated at launch-time.
3468+
/// The size of this memory can differ between launches of a gpu-kernel, depending on
3469+
/// what is specified at launch-time.
3470+
/// However, the alignment is fixed by the kernel itself, at compile-time.
3471+
///
3472+
/// The returned pointer is the start of the workgroup memory region that is
3473+
/// allocated at launch-time.
3474+
/// All calls to `gpu_launch_sized_workgroup_mem` in a workgroup, independent of the
3475+
/// generic type, return the same address, so they alias the same memory.
3476+
/// The returned pointer is aligned by at least the alignment of `T`.
3477+
///
3478+
/// # Safety
3479+
///
3480+
/// The pointer is safe to dereference from the start (the returned pointer) up to the
3481+
/// size of workgroup memory that was specified when launching the current gpu-kernel.
3482+
///
3483+
/// The user must take care of synchronizing access to workgroup memory between
3484+
/// threads in a workgroup. The usual data race requirements apply.
3485+
///
3486+
/// # Other APIs
3487+
///
3488+
/// CUDA and HIP call this dynamic shared memory, shared between threads in a block.
3489+
/// OpenCL and SYCL call this local memory, shared between threads in a work-group.
3490+
/// GLSL calls this shared memory, shared between invocations in a work group.
3491+
/// DirectX calls this groupshared memory, shared between threads in a thread-group.
3492+
#[must_use = "returns a pointer that does nothing unless used"]
3493+
#[rustc_intrinsic]
3494+
#[rustc_nounwind]
3495+
#[unstable(feature = "gpu_launch_sized_workgroup_mem", issue = "135513")]
3496+
#[cfg(any(target_arch = "amdgpu", target_arch = "nvptx64"))]
3497+
pub fn gpu_launch_sized_workgroup_mem<T>() -> *mut T;
3498+
34603499
/// Loads an argument of type `T` from the `va_list` `ap` and increment the
34613500
/// argument `ap` points to.
34623501
///

src/tools/tidy/src/style.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,10 @@ fn should_ignore(line: &str) -> bool {
222222
|| static_regex!(
223223
"\\s*//@ \\!?(count|files|has|has-dir|hasraw|matches|matchesraw|snapshot)\\s.*"
224224
).is_match(line)
225+
// Matching for FileCheck checks
226+
|| static_regex!(
227+
"\\s*// [a-zA-Z0-9-_]*:\\s.*"
228+
).is_match(line)
225229
}
226230

227231
/// Returns `true` if `line` is allowed to be longer than the normal limit.

0 commit comments

Comments
 (0)