Skip to content

Commit 788ce48

Browse files
committed
move initialization of omp/ol runtimes into global_ctor/dtor
1 parent 1b39278 commit 788ce48

3 files changed

Lines changed: 77 additions & 31 deletions

File tree

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 69 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ pub(crate) struct OffloadGlobals<'ll> {
1818
pub launcher_fn: &'ll llvm::Value,
1919
pub launcher_ty: &'ll llvm::Type,
2020

21-
pub bin_desc: &'ll llvm::Type,
22-
21+
//pub bin_desc: &'ll llvm::Type,
2322
pub kernel_args_ty: &'ll llvm::Type,
2423

2524
pub offload_entry_ty: &'ll llvm::Type,
@@ -30,8 +29,6 @@ pub(crate) struct OffloadGlobals<'ll> {
3029

3130
pub ident_t_global: &'ll llvm::Value,
3231

33-
pub register_lib: &'ll llvm::Value,
34-
pub unregister_lib: &'ll llvm::Value,
3532
pub init_rtls: &'ll llvm::Value,
3633
}
3734

@@ -43,35 +40,92 @@ impl<'ll> OffloadGlobals<'ll> {
4340
let (begin_mapper, _, end_mapper, mapper_fn_ty) = gen_tgt_data_mappers(cx);
4441
let ident_t_global = generate_at_one(cx);
4542

46-
let tptr = cx.type_ptr();
47-
let ti32 = cx.type_i32();
48-
let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
49-
let bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
50-
cx.set_struct_body(bin_desc, &tgt_bin_desc_ty, false);
43+
//let tptr = cx.type_ptr();
44+
//let ti32 = cx.type_i32();
45+
//let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
46+
//let bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
47+
//cx.set_struct_body(bin_desc, &tgt_bin_desc_ty, false);
5148

52-
let reg_lib_decl = cx.type_func(&[cx.type_ptr()], cx.type_void());
53-
let register_lib = declare_offload_fn(&cx, "__tgt_register_lib", reg_lib_decl);
54-
let unregister_lib = declare_offload_fn(&cx, "__tgt_unregister_lib", reg_lib_decl);
5549
let init_ty = cx.type_func(&[], cx.type_void());
5650
let init_rtls = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
5751

5852
OffloadGlobals {
5953
launcher_fn,
6054
launcher_ty,
61-
bin_desc,
55+
//bin_desc,
6256
kernel_args_ty,
6357
offload_entry_ty,
6458
begin_mapper,
6559
end_mapper,
6660
mapper_fn_ty,
6761
ident_t_global,
68-
register_lib,
69-
unregister_lib,
7062
init_rtls,
7163
}
7264
}
7365
}
7466

67+
pub(crate) fn setup<'ll>(cx: &CodegenCx<'ll, '_>) {
68+
let reg_lib_decl = cx.type_func(&[cx.type_ptr()], cx.type_void());
69+
let register_lib = declare_offload_fn(&cx, "__tgt_register_lib", reg_lib_decl);
70+
let unregister_lib = declare_offload_fn(&cx, "__tgt_unregister_lib", reg_lib_decl);
71+
72+
//pub(crate) fn create_used_variable_impl(&self, name: &'static CStr, values: &[&'ll Value]) {
73+
let i32_0 = cx.get_const_i32(0);
74+
let ptr_null = cx.const_null(cx.type_ptr());
75+
let const_struct = cx.const_struct(&[i32_0, ptr_null, ptr_null, ptr_null], false);
76+
let omp_descriptor =
77+
add_global(cx, ".omp_offloading.descriptor", const_struct, InternalLinkage);
78+
// @.omp_offloading.descriptor = internal constant %__tgt_bin_desc { i32 1, ptr @.omp_offloading.device_images, ptr @__start_llvm_offload_entries, ptr @__stop_llvm_offload_entries }
79+
// @.omp_offloading.descriptor = internal constant %__tgt_bin_desc { i32 0, ptr null, ptr null, ptr null }
80+
unsafe { llvm::LLVMDumpModule(cx.llmod()) };
81+
82+
let atexit = cx.type_func(&[cx.type_ptr()], cx.type_i32());
83+
let atexit_fn = declare_offload_fn(cx, "atexit", atexit);
84+
// declare i32 @atexit(ptr)
85+
86+
let reg_name = ".omp_offloading.descriptor_reg";
87+
let unreg_name = ".omp_offloading.descriptor_unreg";
88+
let desc_ty = cx.type_func(&[], cx.type_void());
89+
90+
let desc_reg_fn = declare_offload_fn(cx, reg_name, desc_ty);
91+
let desc_unreg_fn = declare_offload_fn(cx, unreg_name, desc_ty);
92+
llvm::set_linkage(desc_reg_fn, InternalLinkage);
93+
llvm::set_linkage(desc_unreg_fn, InternalLinkage);
94+
llvm::set_section(desc_reg_fn, c".text.startup");
95+
llvm::set_section(desc_unreg_fn, c".text.startup");
96+
97+
// define internal void @.omp_offloading.descriptor_reg() section ".text.startup" {
98+
// entry:
99+
// call void @__tgt_register_lib(ptr @.omp_offloading.descriptor)
100+
// %0 = call i32 @atexit(ptr @.omp_offloading.descriptor_unreg)
101+
// ret void
102+
// }
103+
let bb = Builder::append_block(cx, desc_reg_fn, "entry");
104+
let mut a = Builder::build(cx, bb);
105+
a.call(reg_lib_decl, None, None, register_lib, &[omp_descriptor], None, None);
106+
a.call(atexit, None, None, atexit_fn, &[desc_unreg_fn], None, None);
107+
a.ret_void();
108+
109+
// define internal void @.omp_offloading.descriptor_unreg() section ".text.startup" {
110+
// entry:
111+
// call void @__tgt_unregister_lib(ptr @.omp_offloading.descriptor)
112+
// ret void
113+
// }
114+
let bb = Builder::append_block(cx, desc_unreg_fn, "entry");
115+
let mut a = Builder::build(cx, bb);
116+
a.call(reg_lib_decl, None, None, unregister_lib, &[omp_descriptor], None, None);
117+
a.ret_void();
118+
119+
unsafe { llvm::LLVMDumpModule(cx.llmod()) };
120+
121+
let args = vec![cx.get_const_i32(101), desc_reg_fn, ptr_null];
122+
let const_struct = cx.const_struct(&args, false);
123+
let arr = cx.const_array(cx.val_ty(const_struct), &[const_struct]);
124+
let _global_ctor = add_global(cx, "llvm.global_ctors", arr, AppendingLinkage);
125+
// @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 101, ptr @.omp_offloading.descriptor_reg, ptr null }]
126+
unsafe { llvm::LLVMDumpModule(cx.llmod()) };
127+
}
128+
75129
pub(crate) struct OffloadKernelDims<'ll> {
76130
num_workgroups: &'ll Value,
77131
threads_per_block: &'ll Value,
@@ -478,9 +532,6 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
478532
let tgt_decl = offload_globals.launcher_fn;
479533
let tgt_target_kernel_ty = offload_globals.launcher_ty;
480534

481-
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
482-
let tgt_bin_desc = offload_globals.bin_desc;
483-
484535
let tgt_kernel_decl = offload_globals.kernel_args_ty;
485536
let begin_mapper_decl = offload_globals.begin_mapper;
486537
let end_mapper_decl = offload_globals.end_mapper;
@@ -504,12 +555,9 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
504555
}
505556

506557
// Step 0)
507-
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
508-
// %6 = alloca %struct.__tgt_bin_desc, align 8
509558
unsafe {
510559
llvm::LLVMRustPositionBuilderPastAllocas(&builder.llbuilder, builder.llfn());
511560
}
512-
let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc");
513561

514562
let ty = cx.type_array(cx.type_ptr(), num_args);
515563
// Baseptr are just the input pointer to the kernel, stored in a local alloca
@@ -527,7 +575,6 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
527575
unsafe {
528576
llvm::LLVMPositionBuilderAtEnd(&builder.llbuilder, bb);
529577
}
530-
builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);
531578

532579
// Now we allocate once per function param, a copy to be passed to one of our maps.
533580
let mut vals = vec![];
@@ -539,15 +586,9 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
539586
geps.push(gep);
540587
}
541588

542-
let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
543-
let register_lib_decl = offload_globals.register_lib;
544-
let unregister_lib_decl = offload_globals.unregister_lib;
545589
let init_ty = cx.type_func(&[], cx.type_void());
546590
let init_rtls_decl = offload_globals.init_rtls;
547591

548-
// FIXME(offload): Later we want to add them to the wrapper code, rather than our main function.
549-
// call void @__tgt_register_lib(ptr noundef %6)
550-
builder.call(mapper_fn_ty, None, None, register_lib_decl, &[tgt_bin_desc_alloca], None, None);
551592
// call void @__tgt_init_all_rtls()
552593
builder.call(init_ty, None, None, init_rtls_decl, &[], None, None);
553594

@@ -644,6 +685,4 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
644685
num_args,
645686
s_ident_t,
646687
);
647-
648-
builder.call(mapper_fn_ty, None, None, unregister_lib_decl, &[tgt_bin_desc_alloca], None, None);
649688
}

compiler/rustc_codegen_llvm/src/common.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
124124
pub(crate) fn const_null(&self, t: &'ll Type) -> &'ll Value {
125125
unsafe { llvm::LLVMConstNull(t) }
126126
}
127+
128+
/// Builds a constant struct value from `elts` by forwarding to `struct_in_context`
/// with this context's LLVM context (`self.llcx()`); `packed` is passed through
/// unchanged.
pub(crate) fn const_struct(&self, elts: &[&'ll Value], packed: bool) -> &'ll Value {
129+
struct_in_context(self.llcx(), elts, packed)
130+
}
127131
}
128132

129133
impl<'ll, 'tcx> ConstCodegenMethods for CodegenCx<'ll, 'tcx> {

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ use tracing::debug;
3030
use crate::abi::FnAbiLlvmExt;
3131
use crate::builder::Builder;
3232
use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call};
33-
use crate::builder::gpu_offload::{OffloadKernelDims, gen_call_handling, gen_define_handling};
33+
use crate::builder::gpu_offload::{
34+
OffloadKernelDims, gen_call_handling, gen_define_handling, setup,
35+
};
3436
use crate::context::CodegenCx;
3537
use crate::declare::declare_raw_fn;
3638
use crate::errors::{
@@ -1403,6 +1405,7 @@ fn codegen_offload<'ll, 'tcx>(
14031405
return;
14041406
}
14051407
};
1408+
setup(cx);
14061409
let offload_data = gen_define_handling(&cx, &metadata, &types, target_symbol, offload_globals);
14071410
gen_call_handling(bx, &offload_data, &args, &types, &metadata, offload_globals, &offload_dims);
14081411
}

0 commit comments

Comments
 (0)