@@ -3,11 +3,12 @@ use std::ffi::CString;
33use bitflags:: Flags ;
44use llvm:: Linkage :: * ;
55use rustc_abi:: Align ;
6+ use rustc_codegen_ssa:: MemFlags ;
67use rustc_codegen_ssa:: common:: TypeKind ;
78use rustc_codegen_ssa:: mir:: operand:: { OperandRef , OperandValue } ;
89use rustc_codegen_ssa:: traits:: { BaseTypeCodegenMethods , BuilderMethods } ;
910use rustc_middle:: bug;
10- use rustc_middle:: ty:: offload_meta:: { MappingFlags , OffloadMetadata } ;
11+ use rustc_middle:: ty:: offload_meta:: { MappingFlags , OffloadMetadata , OffloadSize } ;
1112
1213use crate :: builder:: Builder ;
1314use crate :: common:: CodegenCx ;
@@ -450,7 +451,15 @@ pub(crate) fn gen_define_handling<'ll>(
450451 // FIXME(offload): add `OMP_MAP_TARGET_PARAM = 0x20` only if necessary
451452 let transfer_kernel = vec ! [ MappingFlags :: TARGET_PARAM . bits( ) ; transfer_to. len( ) ] ;
452453
453- let offload_sizes = add_priv_unnamed_arr ( & cx, & format ! ( ".offload_sizes.{symbol}" ) , & sizes) ;
454+ let actual_sizes = sizes
455+ . iter ( )
456+ . map ( |s| match s {
457+ OffloadSize :: Static ( sz) => * sz,
458+ OffloadSize :: Dynamic => 0 ,
459+ } )
460+ . collect :: < Vec < _ > > ( ) ;
461+ let offload_sizes =
462+ add_priv_unnamed_arr ( & cx, & format ! ( ".offload_sizes.{symbol}" ) , & actual_sizes) ;
454463 let memtransfer_begin =
455464 add_priv_unnamed_arr ( & cx, & format ! ( ".offload_maptypes.{symbol}.begin" ) , & transfer_to) ;
456465 let memtransfer_kernel =
@@ -499,9 +508,6 @@ pub(crate) fn gen_define_handling<'ll>(
499508 region_id,
500509 } ;
501510
502- // FIXME(Sa4dUs): use this global for constant offload sizes
503- cx. add_compiler_used_global ( result. offload_sizes ) ;
504-
505511 cx. offload_kernel_cache . borrow_mut ( ) . insert ( symbol, result) ;
506512
507513 result
@@ -535,6 +541,15 @@ pub(crate) fn scalar_width<'ll>(cx: &'ll SimpleCx<'_>, ty: &'ll Type) -> u64 {
535541 }
536542}
537543
544+ fn get_runtime_size < ' ll , ' tcx > (
545+ _cx : & CodegenCx < ' ll , ' tcx > ,
546+ _val : & ' ll Value ,
547+ _meta : & OffloadMetadata ,
548+ ) -> & ' ll Value {
549+ // FIXME(Sa4dUs): handle dynamic-size data (e.g. slices)
550+ bug ! ( "offload does not support dynamic sizes yet" ) ;
551+ }
552+
538553// For each kernel *call*, we now use some of our previously declared globals to move data to and from
539554// the GPU. For now, we only handle the data transfer part of it.
540555// If two consecutive kernels use the same memory, we still move it to the host and back to the GPU.
@@ -564,15 +579,17 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
564579) {
565580 let cx = builder. cx ;
566581 let OffloadKernelGlobals {
582+ offload_sizes,
567583 memtransfer_begin,
568584 memtransfer_kernel,
569585 memtransfer_end,
570586 region_id,
571- ..
572587 } = offload_data;
573588 let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } =
574589 offload_dims;
575590
591+ let has_dynamic = metadata. iter ( ) . any ( |m| matches ! ( m. payload_size, OffloadSize :: Dynamic ) ) ;
592+
576593 let tgt_decl = offload_globals. launcher_fn ;
577594 let tgt_target_kernel_ty = offload_globals. launcher_ty ;
578595
@@ -596,7 +613,24 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
596613 let a2 = builder. direct_alloca ( ty, Align :: EIGHT , ".offload_ptrs" ) ;
597614 // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
598615 let ty2 = cx. type_array ( cx. type_i64 ( ) , num_args) ;
599- let a4 = builder. direct_alloca ( ty2, Align :: EIGHT , ".offload_sizes" ) ;
616+
617+ let a4 = if has_dynamic {
618+ let alloc = builder. direct_alloca ( ty2, Align :: EIGHT , ".offload_sizes" ) ;
619+
620+ builder. memcpy (
621+ alloc,
622+ Align :: EIGHT ,
623+ offload_sizes,
624+ Align :: EIGHT ,
625+ cx. get_const_i64 ( 8 * args. len ( ) as u64 ) ,
626+ MemFlags :: empty ( ) ,
627+ None ,
628+ ) ;
629+
630+ alloc
631+ } else {
632+ offload_sizes
633+ } ;
600634
601635 //%kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
602636 let a5 = builder. direct_alloca ( tgt_kernel_decl, Align :: EIGHT , "kernel_args" ) ;
@@ -648,9 +682,12 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
648682 builder. store ( vals[ i as usize ] , gep1, Align :: EIGHT ) ;
649683 let gep2 = builder. inbounds_gep ( ty, a2, & [ i32_0, idx] ) ;
650684 builder. store ( geps[ i as usize ] , gep2, Align :: EIGHT ) ;
651- let gep3 = builder. inbounds_gep ( ty2, a4, & [ i32_0, idx] ) ;
652- // FIXME(offload): write an offload frontend and handle arbitrary types.
653- builder. store ( cx. get_const_i64 ( metadata[ i as usize ] . payload_size ) , gep3, Align :: EIGHT ) ;
685+
686+ if matches ! ( metadata[ i as usize ] . payload_size, OffloadSize :: Dynamic ) {
687+ let gep3 = builder. inbounds_gep ( ty2, a4, & [ i32_0, idx] ) ;
688+ let size_val = get_runtime_size ( cx, args[ i as usize ] , & metadata[ i as usize ] ) ;
689+ builder. store ( size_val, gep3, Align :: EIGHT ) ;
690+ }
654691 }
655692
656693 // For now we have a very simplistic indexing scheme into our
0 commit comments