@@ -8,7 +8,7 @@ use rustc_middle::bug;
88use rustc_middle:: ty:: offload_meta:: OffloadMetadata ;
99
1010use crate :: builder:: Builder ;
11- use crate :: common:: CodegenCx ;
11+ use crate :: common:: { AsCCharPtr , CodegenCx } ;
1212use crate :: llvm:: AttributePlace :: Function ;
1313use crate :: llvm:: { self , Linkage , Type , Value } ;
1414use crate :: { SimpleCx , attributes} ;
@@ -288,7 +288,7 @@ pub(crate) struct OffloadKernelGlobals<'ll> {
288288 pub offload_sizes : & ' ll llvm:: Value ,
289289 pub memtransfer_types : & ' ll llvm:: Value ,
290290 pub region_id : & ' ll llvm:: Value ,
291- pub offload_entry : & ' ll llvm:: Value ,
291+ pub offload_entry : Option < & ' ll llvm:: Value > ,
292292}
293293
294294fn gen_tgt_data_mappers < ' ll > (
@@ -359,16 +359,19 @@ pub(crate) fn gen_define_handling<'ll>(
359359 types : & [ & ' ll Type ] ,
360360 symbol : String ,
361361 offload_globals : & OffloadGlobals < ' ll > ,
362+ host : bool ,
362363) -> OffloadKernelGlobals < ' ll > {
363364 if let Some ( entry) = cx. offload_kernel_cache . borrow ( ) . get ( & symbol) {
364365 return * entry;
365366 }
366367
367368 let offload_entry_ty = offload_globals. offload_entry_ty ;
368369
370+ let mut arg_iter = types. iter ( ) . zip ( metadata) ;
371+
369372 // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
370373 // reference) types.
371- let ptr_meta = types . iter ( ) . zip ( metadata ) . filter_map ( |( & x, meta) | match cx. type_kind ( x) {
374+ let ptr_meta = arg_iter . filter_map ( |( & x, meta) | match cx. type_kind ( x) {
372375 rustc_codegen_ssa:: common:: TypeKind :: Pointer => Some ( meta) ,
373376 _ => None ,
374377 } ) ;
@@ -393,29 +396,34 @@ pub(crate) fn gen_define_handling<'ll>(
393396 let initializer = cx. get_const_i8 ( 0 ) ;
394397 let region_id = add_global ( & cx, & name, initializer, WeakAnyLinkage ) ;
395398
396- let c_entry_name = CString :: new ( symbol. clone ( ) ) . unwrap ( ) ;
397- let c_val = c_entry_name. as_bytes_with_nul ( ) ;
398- let offload_entry_name = format ! ( ".offloading.entry_name.{symbol}" ) ;
399-
400- let initializer = crate :: common:: bytes_in_context ( cx. llcx , c_val) ;
401- let llglobal = add_unnamed_global ( & cx, & offload_entry_name, initializer, InternalLinkage ) ;
402- llvm:: set_alignment ( llglobal, Align :: ONE ) ;
403- llvm:: set_section ( llglobal, c".llvm.rodata.offloading" ) ;
404-
405- let name = format ! ( ".offloading.entry.{symbol}" ) ;
406-
407- // See the __tgt_offload_entry documentation above.
408- let elems = TgtOffloadEntry :: new ( & cx, region_id, llglobal) ;
409-
410- let initializer = crate :: common:: named_struct ( offload_entry_ty, & elems) ;
411- let c_name = CString :: new ( name) . unwrap ( ) ;
412- let offload_entry = llvm:: add_global ( cx. llmod , offload_entry_ty, & c_name) ;
413- llvm:: set_global_constant ( offload_entry, true ) ;
414- llvm:: set_linkage ( offload_entry, WeakAnyLinkage ) ;
415- llvm:: set_initializer ( offload_entry, initializer) ;
416- llvm:: set_alignment ( offload_entry, Align :: EIGHT ) ;
417- let c_section_name = CString :: new ( "llvm_offload_entries" ) . unwrap ( ) ;
418- llvm:: set_section ( offload_entry, & c_section_name) ;
399+ let offload_entry = if !host {
400+ let c_entry_name = CString :: new ( symbol. clone ( ) ) . unwrap ( ) ;
401+ let c_val = c_entry_name. as_bytes_with_nul ( ) ;
402+ let offload_entry_name = format ! ( ".offloading.entry_name.{symbol}" ) ;
403+
404+ let initializer = crate :: common:: bytes_in_context ( cx. llcx , c_val) ;
405+ let llglobal = add_unnamed_global ( & cx, & offload_entry_name, initializer, InternalLinkage ) ;
406+ llvm:: set_alignment ( llglobal, Align :: ONE ) ;
407+ llvm:: set_section ( llglobal, c".llvm.rodata.offloading" ) ;
408+
409+ let name = format ! ( ".offloading.entry.{symbol}" ) ;
410+
411+ // See the __tgt_offload_entry documentation above.
412+ let elems = TgtOffloadEntry :: new ( & cx, region_id, llglobal) ;
413+
414+ let initializer = crate :: common:: named_struct ( offload_entry_ty, & elems) ;
415+ let c_name = CString :: new ( name) . unwrap ( ) ;
416+ let offload_entry = llvm:: add_global ( cx. llmod , offload_entry_ty, & c_name) ;
417+ llvm:: set_global_constant ( offload_entry, true ) ;
418+ llvm:: set_linkage ( offload_entry, WeakAnyLinkage ) ;
419+ llvm:: set_initializer ( offload_entry, initializer) ;
420+ llvm:: set_alignment ( offload_entry, Align :: EIGHT ) ;
421+ let c_section_name = CString :: new ( "llvm_offload_entries" ) . unwrap ( ) ;
422+ llvm:: set_section ( offload_entry, & c_section_name) ;
423+ Some ( offload_entry)
424+ } else {
425+ None
426+ } ;
419427
420428 let result =
421429 OffloadKernelGlobals { offload_sizes, memtransfer_types, region_id, offload_entry } ;
@@ -467,7 +475,12 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
467475 metadata : & [ OffloadMetadata ] ,
468476 offload_globals : & OffloadGlobals < ' ll > ,
469477 offload_dims : & OffloadKernelDims < ' ll > ,
478+ host : bool ,
479+ host_fnc_name : String ,
480+ host_llfn : & ' ll Value ,
481+ host_llty : & ' ll Type ,
470482) {
483+ dbg ! ( & host_llfn) ;
471484 let cx = builder. cx ;
472485 let OffloadKernelGlobals { offload_sizes, offload_entry, memtransfer_types, region_id } =
473486 offload_data;
@@ -490,7 +503,8 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
490503
491504 // FIXME(Sa4dUs): dummy loads are a temp workaround, we should find a proper way to prevent these
492505 // variables from being optimized away
493- for val in [ offload_sizes, offload_entry] {
506+ for val in [ offload_sizes] {
507+ //for val in [offload_sizes, offload_entry] {
494508 unsafe {
495509 let dummy = llvm:: LLVMBuildLoad2 (
496510 & builder. llbuilder ,
@@ -623,14 +637,59 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
623637 let args = vec ! [
624638 s_ident_t,
625639 // FIXME(offload) give users a way to select which GPU to use.
640+ //cx.get_const_i64(0), // MAX == -1.
626641 cx. get_const_i64( u64 :: MAX ) , // MAX == -1.
627642 num_workgroups,
628643 threads_per_block,
629644 region_id,
630645 a5,
631646 ] ;
632- builder. call ( tgt_target_kernel_ty, None , None , tgt_decl, & args, None , None ) ;
633- // %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
647+ if host {
648+ //let new_args = vec![];
649+ //dbg!(&args);// were overwritten above, now hold omp stuff (shadowed)
650+ dbg ! ( & vals) ;
651+ dbg ! ( & types) ;
652+ //dbg!(&fn_ty);
653+ //let host_llty = cx.llvm_ty(host_llfn);
654+ dbg ! ( & host_fnc_name) ;
655+ dbg ! ( cx. get_defined_value( & host_fnc_name) ) ;
656+ //let user_fn = unsafe {
657+ // llvm::LLVMRustGetOrInsertFunction(
658+ // builder.llmod,
659+ // host_fnc_name.as_c_char_ptr(),
660+ // host_fnc_name.len(),
661+ // fn_ty,
662+ // )
663+ //};
664+ //dbg!(&user_fn);
665+ // void *omp_get_mapped_ptr(void *ptr, int device_num);
666+
667+ let fn_name = "omp_get_mapped_ptr" ;
668+ let ty2: & ' ll Type = cx. type_func ( & [ cx. type_ptr ( ) , cx. type_i32 ( ) ] , cx. type_ptr ( ) ) ;
669+ let mapper_fn = unsafe {
670+ llvm:: LLVMRustGetOrInsertFunction (
671+ builder. llmod ,
672+ fn_name. as_c_char_ptr ( ) ,
673+ fn_name. len ( ) ,
674+ ty2,
675+ )
676+ } ;
677+
678+ let mut device_vals = Vec :: with_capacity ( vals. len ( ) ) ;
679+ let device_num = cx. get_const_i32 ( 0 ) ;
680+ for arg in vals {
681+ dbg ! ( & mapper_fn) ;
682+ dbg ! ( & ty2) ;
683+ let device_arg =
684+ builder. call ( ty2, None , None , mapper_fn, & [ arg, device_num] , None , None ) ;
685+ device_vals. push ( device_arg) ;
686+ }
687+ builder. call ( host_llty, None , None , host_llfn, & device_vals, None , None ) ;
688+ //builder.call(fn_ty, None, None, host_llfn, &vals, None, None);
689+ } else {
690+ builder. call ( tgt_target_kernel_ty, None , None , tgt_decl, & args, None , None ) ;
691+ // %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
692+ }
634693
635694 // Step 4)
636695 let geps = get_geps ( builder, ty, ty2, a1, a2, a4) ;
0 commit comments