Skip to content

Commit 020f669

Browse files
committed
Adding a new, working offload_args intrinsic, which only maps arguments, but calls some host code with device ptrs
1 parent 4c3310a commit 020f669

5 files changed

Lines changed: 167 additions & 32 deletions

File tree

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 88 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use rustc_middle::bug;
88
use rustc_middle::ty::offload_meta::OffloadMetadata;
99

1010
use crate::builder::Builder;
11-
use crate::common::CodegenCx;
11+
use crate::common::{AsCCharPtr, CodegenCx};
1212
use crate::llvm::AttributePlace::Function;
1313
use crate::llvm::{self, Linkage, Type, Value};
1414
use crate::{SimpleCx, attributes};
@@ -288,7 +288,7 @@ pub(crate) struct OffloadKernelGlobals<'ll> {
288288
pub offload_sizes: &'ll llvm::Value,
289289
pub memtransfer_types: &'ll llvm::Value,
290290
pub region_id: &'ll llvm::Value,
291-
pub offload_entry: &'ll llvm::Value,
291+
pub offload_entry: Option<&'ll llvm::Value>,
292292
}
293293

294294
fn gen_tgt_data_mappers<'ll>(
@@ -359,16 +359,19 @@ pub(crate) fn gen_define_handling<'ll>(
359359
types: &[&'ll Type],
360360
symbol: String,
361361
offload_globals: &OffloadGlobals<'ll>,
362+
host: bool,
362363
) -> OffloadKernelGlobals<'ll> {
363364
if let Some(entry) = cx.offload_kernel_cache.borrow().get(&symbol) {
364365
return *entry;
365366
}
366367

367368
let offload_entry_ty = offload_globals.offload_entry_ty;
368369

370+
let mut arg_iter = types.iter().zip(metadata);
371+
369372
// It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
370373
// reference) types.
371-
let ptr_meta = types.iter().zip(metadata).filter_map(|(&x, meta)| match cx.type_kind(x) {
374+
let ptr_meta = arg_iter.filter_map(|(&x, meta)| match cx.type_kind(x) {
372375
rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta),
373376
_ => None,
374377
});
@@ -393,29 +396,34 @@ pub(crate) fn gen_define_handling<'ll>(
393396
let initializer = cx.get_const_i8(0);
394397
let region_id = add_global(&cx, &name, initializer, WeakAnyLinkage);
395398

396-
let c_entry_name = CString::new(symbol.clone()).unwrap();
397-
let c_val = c_entry_name.as_bytes_with_nul();
398-
let offload_entry_name = format!(".offloading.entry_name.{symbol}");
399-
400-
let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
401-
let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
402-
llvm::set_alignment(llglobal, Align::ONE);
403-
llvm::set_section(llglobal, c".llvm.rodata.offloading");
404-
405-
let name = format!(".offloading.entry.{symbol}");
406-
407-
// See the __tgt_offload_entry documentation above.
408-
let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);
409-
410-
let initializer = crate::common::named_struct(offload_entry_ty, &elems);
411-
let c_name = CString::new(name).unwrap();
412-
let offload_entry = llvm::add_global(cx.llmod, offload_entry_ty, &c_name);
413-
llvm::set_global_constant(offload_entry, true);
414-
llvm::set_linkage(offload_entry, WeakAnyLinkage);
415-
llvm::set_initializer(offload_entry, initializer);
416-
llvm::set_alignment(offload_entry, Align::EIGHT);
417-
let c_section_name = CString::new("llvm_offload_entries").unwrap();
418-
llvm::set_section(offload_entry, &c_section_name);
399+
let offload_entry = if !host {
400+
let c_entry_name = CString::new(symbol.clone()).unwrap();
401+
let c_val = c_entry_name.as_bytes_with_nul();
402+
let offload_entry_name = format!(".offloading.entry_name.{symbol}");
403+
404+
let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
405+
let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
406+
llvm::set_alignment(llglobal, Align::ONE);
407+
llvm::set_section(llglobal, c".llvm.rodata.offloading");
408+
409+
let name = format!(".offloading.entry.{symbol}");
410+
411+
// See the __tgt_offload_entry documentation above.
412+
let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);
413+
414+
let initializer = crate::common::named_struct(offload_entry_ty, &elems);
415+
let c_name = CString::new(name).unwrap();
416+
let offload_entry = llvm::add_global(cx.llmod, offload_entry_ty, &c_name);
417+
llvm::set_global_constant(offload_entry, true);
418+
llvm::set_linkage(offload_entry, WeakAnyLinkage);
419+
llvm::set_initializer(offload_entry, initializer);
420+
llvm::set_alignment(offload_entry, Align::EIGHT);
421+
let c_section_name = CString::new("llvm_offload_entries").unwrap();
422+
llvm::set_section(offload_entry, &c_section_name);
423+
Some(offload_entry)
424+
} else {
425+
None
426+
};
419427

420428
let result =
421429
OffloadKernelGlobals { offload_sizes, memtransfer_types, region_id, offload_entry };
@@ -467,7 +475,12 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
467475
metadata: &[OffloadMetadata],
468476
offload_globals: &OffloadGlobals<'ll>,
469477
offload_dims: &OffloadKernelDims<'ll>,
478+
host: bool,
479+
host_fnc_name: String,
480+
host_llfn: &'ll Value,
481+
host_llty: &'ll Type,
470482
) {
483+
dbg!(&host_llfn);
471484
let cx = builder.cx;
472485
let OffloadKernelGlobals { offload_sizes, offload_entry, memtransfer_types, region_id } =
473486
offload_data;
@@ -490,7 +503,8 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
490503

491504
// FIXME(Sa4dUs): dummy loads are a temp workaround, we should find a proper way to prevent these
492505
// variables from being optimized away
493-
for val in [offload_sizes, offload_entry] {
506+
for val in [offload_sizes] {
507+
//for val in [offload_sizes, offload_entry] {
494508
unsafe {
495509
let dummy = llvm::LLVMBuildLoad2(
496510
&builder.llbuilder,
@@ -623,14 +637,59 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
623637
let args = vec![
624638
s_ident_t,
625639
// FIXME(offload) give users a way to select which GPU to use.
640+
//cx.get_const_i64(0), // MAX == -1.
626641
cx.get_const_i64(u64::MAX), // MAX == -1.
627642
num_workgroups,
628643
threads_per_block,
629644
region_id,
630645
a5,
631646
];
632-
builder.call(tgt_target_kernel_ty, None, None, tgt_decl, &args, None, None);
633-
// %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
647+
if host {
648+
//let new_args = vec![];
649+
//dbg!(&args);// were overwritten above, now hold omp stuff (shadowed)
650+
dbg!(&vals);
651+
dbg!(&types);
652+
//dbg!(&fn_ty);
653+
//let host_llty = cx.llvm_ty(host_llfn);
654+
dbg!(&host_fnc_name);
655+
dbg!(cx.get_defined_value(&host_fnc_name));
656+
//let user_fn = unsafe {
657+
// llvm::LLVMRustGetOrInsertFunction(
658+
// builder.llmod,
659+
// host_fnc_name.as_c_char_ptr(),
660+
// host_fnc_name.len(),
661+
// fn_ty,
662+
// )
663+
//};
664+
//dbg!(&user_fn);
665+
// void *omp_get_mapped_ptr(void *ptr, int device_num);
666+
667+
let fn_name = "omp_get_mapped_ptr";
668+
let ty2: &'ll Type = cx.type_func(&[cx.type_ptr(), cx.type_i32()], cx.type_ptr());
669+
let mapper_fn = unsafe {
670+
llvm::LLVMRustGetOrInsertFunction(
671+
builder.llmod,
672+
fn_name.as_c_char_ptr(),
673+
fn_name.len(),
674+
ty2,
675+
)
676+
};
677+
678+
let mut device_vals = Vec::with_capacity(vals.len());
679+
let device_num = cx.get_const_i32(0);
680+
for arg in vals {
681+
dbg!(&mapper_fn);
682+
dbg!(&ty2);
683+
let device_arg =
684+
builder.call(ty2, None, None, mapper_fn, &[arg, device_num], None, None);
685+
device_vals.push(device_arg);
686+
}
687+
builder.call(host_llty, None, None, host_llfn, &device_vals, None, None);
688+
//builder.call(fn_ty, None, None, host_llfn, &vals, None, None);
689+
} else {
690+
builder.call(tgt_target_kernel_ty, None, None, tgt_decl, &args, None, None);
691+
// %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
692+
}
634693

635694
// Step 4)
636695
let geps = get_geps(builder, ty, ty2, a1, a2, a4);

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,19 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
215215
let _ = tcx.dcx().emit_almost_fatal(OffloadWithoutFatLTO);
216216
}
217217

218-
codegen_offload(self, tcx, instance, args);
218+
codegen_offload(self, tcx, instance, args, false);
219+
return Ok(());
220+
}
221+
sym::offload_args => {
222+
if tcx.sess.opts.unstable_opts.offload.is_empty() {
223+
let _ = tcx.dcx().emit_almost_fatal(OffloadWithoutEnable);
224+
}
225+
226+
if tcx.sess.lto() != rustc_session::config::Lto::Fat {
227+
let _ = tcx.dcx().emit_almost_fatal(OffloadWithoutFatLTO);
228+
}
229+
230+
codegen_offload(self, tcx, instance, args, true);
219231
return Ok(());
220232
}
221233
sym::is_val_statically_known => {
@@ -1362,8 +1374,11 @@ fn codegen_offload<'ll, 'tcx>(
13621374
tcx: TyCtxt<'tcx>,
13631375
instance: ty::Instance<'tcx>,
13641376
args: &[OperandRef<'tcx, &'ll Value>],
1377+
host: bool,
13651378
) {
1379+
dbg!(&args[0]);
13661380
let cx = bx.cx;
1381+
dbg!(&instance);
13671382
let fn_args = instance.args;
13681383

13691384
let (target_id, target_args) = match fn_args.into_type_list(tcx)[0].kind() {
@@ -1383,13 +1398,34 @@ fn codegen_offload<'ll, 'tcx>(
13831398
return;
13841399
}
13851400
};
1401+
let llfn = cx.get_fn(fn_target);
1402+
dbg!(&llfn);
1403+
dbg!(&fn_target);
13861404

13871405
let offload_dims = OffloadKernelDims::from_operands(bx, &args[1], &args[2]);
13881406
let args = get_args_from_tuple(bx, args[3], fn_target);
1407+
dbg!(&args);
13891408
let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target, LOCAL_CRATE);
1409+
dbg!(&target_symbol);
13901410

13911411
let sig = tcx.fn_sig(fn_target.def_id()).skip_binder().skip_binder();
1412+
dbg!(&sig);
13921413
let inputs = sig.inputs();
1414+
// [compiler/rustc_codegen_llvm/src/intrinsic.rs:1405:5] &args = [
1415+
// (ptr: %36 = load ptr, ptr %_14, align 8),
1416+
// (ptr: %38 = load ptr, ptr %37, align 8),
1417+
// (ptr: %40 = load ptr, ptr %39, align 8),
1418+
// (ptr: %42 = load ptr, ptr %41, align 8),
1419+
// (ptr: %44 = load ptr, ptr %43, align 8),
1420+
// (ptr: %46 = load ptr, ptr %45, align 8),
1421+
// (ptr: %48 = load ptr, ptr %47, align 8),
1422+
// (ptr: %50 = load ptr, ptr %49, align 8),
1423+
// (ptr: %52 = load ptr, ptr %51, align 8),
1424+
// (ptr: %54 = load ptr, ptr %53, align 8),
1425+
//]
1426+
//[compiler/rustc_codegen_llvm/src/intrinsic.rs:1407:5] &target_symbol = "rocblas_sgemv_wrapper"
1427+
//[compiler/rustc_codegen_llvm/src/intrinsic.rs:1410:5] &sig = fn(*const i32, *const i32, *const f32, *const f32, *const i32, *const f32, *const i32, *const f32, *mut f32, *const i32)
1428+
//)
13931429

13941430
let metadata = inputs.iter().map(|ty| OffloadMetadata::from_ty(tcx, *ty)).collect::<Vec<_>>();
13951431

@@ -1403,8 +1439,26 @@ fn codegen_offload<'ll, 'tcx>(
14031439
return;
14041440
}
14051441
};
1406-
let offload_data = gen_define_handling(&cx, &metadata, &types, target_symbol, offload_globals);
1407-
gen_call_handling(bx, &offload_data, &args, &types, &metadata, offload_globals, &offload_dims);
1442+
let instance = rustc_middle::ty::Instance::mono(tcx, fn_target.def_id());
1443+
let fn_abi = cx.fn_abi_of_instance(instance, tcx.mk_type_list(&[]));
1444+
let host_fn_ty = fn_abi.llvm_type(cx);
1445+
dbg!(&host_fn_ty);
1446+
1447+
let offload_data =
1448+
gen_define_handling(&cx, &metadata, &types, target_symbol.clone(), offload_globals, host);
1449+
gen_call_handling(
1450+
bx,
1451+
&offload_data,
1452+
&args,
1453+
&types,
1454+
&metadata,
1455+
offload_globals,
1456+
&offload_dims,
1457+
host,
1458+
target_symbol,
1459+
llfn,
1460+
host_fn_ty,
1461+
);
14081462
}
14091463

14101464
fn get_args_from_tuple<'ll, 'tcx>(

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi
164164
| sym::mul_with_overflow
165165
| sym::needs_drop
166166
| sym::offload
167+
| sym::offload_args
167168
| sym::offset_of
168169
| sym::overflow_checks
169170
| sym::powf16
@@ -326,6 +327,17 @@ pub(crate) fn check_intrinsic_type(
326327
],
327328
param(2),
328329
),
330+
sym::offload_args => (
331+
3,
332+
0,
333+
vec![
334+
param(0),
335+
Ty::new_array_with_const_len(tcx, tcx.types.u32, Const::from_target_usize(tcx, 3)),
336+
Ty::new_array_with_const_len(tcx, tcx.types.u32, Const::from_target_usize(tcx, 3)),
337+
param(1),
338+
],
339+
param(2),
340+
),
329341
sym::offset => (2, 0, vec![param(0), param(1)], param(0)),
330342
sym::arith_offset => (
331343
1,

compiler/rustc_span/src/symbol.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1596,6 +1596,7 @@ symbols! {
15961596
of,
15971597
off,
15981598
offload,
1599+
offload_args,
15991600
offset,
16001601
offset_of,
16011602
offset_of_enum,

library/core/src/intrinsics/mod.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3421,6 +3421,15 @@ pub const fn offload<F, T: crate::marker::Tuple, R>(
34213421
args: T,
34223422
) -> R;
34233423

3424+
#[rustc_nounwind]
3425+
#[rustc_intrinsic]
3426+
pub const fn offload_args<F, T: crate::marker::Tuple, R>(
3427+
f: F,
3428+
workgroup_dim: [u32; 3],
3429+
thread_dim: [u32; 3],
3430+
args: T,
3431+
) -> R;
3432+
34243433
/// Inform Miri that a given pointer definitely has a certain alignment.
34253434
#[cfg(miri)]
34263435
#[rustc_allow_const_fn_unstable(const_eval_select)]

0 commit comments

Comments
 (0)