1- use crate :: ast:: { AtomicTy , BaseExec , DataTy , DataTyKind , FunDef , Memory , Nat , NatCtx , ScalarTy , Ty , TyKind } ;
1+ use crate :: ast:: { AtomicTy , BaseExec , DataTy , DataTyKind , FunDef , Memory , Nat , NatCtx , Ownership , ScalarTy , Ty , TyKind } ;
22use melior:: {
33 dialect:: func,
44 ir:: {
@@ -23,35 +23,6 @@ fn nat_to_dimension(nat: &Nat) -> String {
2323 }
2424}
2525
26- /// Helper function to create HACC attributes for GPU functions
27- /// This function creates proper MLIR attribute objects for HACC attributes.
28- /// Note: This requires the HACC dialect to be registered in the MLIR context.
29- fn create_hacc_attributes < ' c > ( context : & ' c Context ) -> Vec < ( Identifier < ' c > , Attribute < ' c > ) > {
30- // Create HACC entry attribute (hacc.entry) - this is a unit attribute
31- let entry_attr = Attribute :: parse ( context, "unit" )
32- . expect ( "Failed to create HACC entry attribute" ) ;
33-
34- // Create HACC function type attribute (hacc.function_kind = DEVICE)
35- // Note: This will fail if the HACC dialect is not registered in the context
36- let func_type_attr = Attribute :: parse ( context, "#hacc.function_kind<DEVICE>" )
37- . expect ( "Failed to create HACC function type attribute - ensure HACC dialect is registered" ) ;
38-
39- vec ! [
40- ( Identifier :: new( context, "hacc.entry" ) , entry_attr) ,
41- ( Identifier :: new( context, "hacc.function_kind" ) , func_type_attr) ,
42- ]
43- }
44-
45- /// Helper function to generate HACC attributes string for GPU functions
46- /// This function generates the MLIR string representation of HACC attributes
47- /// for use in string-based MLIR generation. The proper attribute objects
48- /// are available through create_hacc_attributes() when the HACC dialect is registered.
49- fn generate_hacc_attributes_string ( context : & Context ) -> String {
50- // Generate the string representation directly since we know the exact format
51- // The create_hacc_attributes() function is available for when the HACC dialect is registered
52- " attributes {hacc.entry, hacc.function_kind = #hacc.function_kind<DEVICE>}" . to_string ( )
53- }
54-
5526/// Helper function to convert ScalarTy to MLIR Type
5627fn scalar_ty_to_mlir < ' c > ( scalar_ty : & ScalarTy , context : & ' c Context ) -> Type < ' c > {
5728 match scalar_ty {
@@ -505,73 +476,6 @@ fn collect_parameter_usage(fun: &crate::ast::FunDef) -> std::collections::HashMa
505476 param_usage
506477}
507478
508- /// Collect which parameters are referenced in the function body (legacy function for compatibility)
509- fn collect_used_parameters ( fun : & crate :: ast:: FunDef ) -> std:: collections:: HashSet < String > {
510- use crate :: ast:: { Expr , ExprKind , PlaceExprKind } ;
511- use std:: collections:: HashSet ;
512-
513- let mut used_params = HashSet :: new ( ) ;
514-
515- fn walk_expr ( expr : & Expr , used_params : & mut HashSet < String > ) {
516- match & expr. expr {
517- ExprKind :: PlaceExpr ( place_expr) => {
518- if let PlaceExprKind :: Ident ( ident) = & place_expr. pl_expr {
519- used_params. insert ( ident. name . to_string ( ) ) ;
520- }
521- }
522- ExprKind :: BinOp ( _, lhs, rhs) => {
523- walk_expr ( lhs, used_params) ;
524- walk_expr ( rhs, used_params) ;
525- }
526- ExprKind :: Let ( _, _, value_expr) => {
527- walk_expr ( value_expr, used_params) ;
528- }
529- ExprKind :: Seq ( exprs) => {
530- for expr in exprs {
531- walk_expr ( expr, used_params) ;
532- }
533- }
534- ExprKind :: Assign ( place_expr, value_expr) => {
535- if let PlaceExprKind :: Ident ( ident) = & place_expr. pl_expr {
536- used_params. insert ( ident. name . to_string ( ) ) ;
537- }
538- walk_expr ( value_expr, used_params) ;
539- }
540- ExprKind :: App ( _, _, args) => {
541- for arg in args {
542- walk_expr ( arg, used_params) ;
543- }
544- }
545- ExprKind :: IfElse ( cond, case_true, case_false) => {
546- walk_expr ( cond, used_params) ;
547- walk_expr ( case_true, used_params) ;
548- walk_expr ( case_false, used_params) ;
549- }
550- ExprKind :: If ( cond, case_true) => {
551- walk_expr ( cond, used_params) ;
552- walk_expr ( case_true, used_params) ;
553- }
554- ExprKind :: ForNat ( _, _, body) => {
555- walk_expr ( body, used_params) ;
556- }
557- ExprKind :: Ref ( _, _, place_expr) => {
558- if let PlaceExprKind :: Ident ( ident) = & place_expr. pl_expr {
559- used_params. insert ( ident. name . to_string ( ) ) ;
560- }
561- }
562- ExprKind :: Unsafe ( expr) => {
563- walk_expr ( expr, used_params) ;
564- }
565- _ => {
566- // Other expression types don't contain variable references
567- }
568- }
569- }
570-
571- walk_expr ( & fun. body . body , & mut used_params) ;
572- used_params
573- }
574-
575479/// Generate body operations for GPU functions
576480fn generate_body_operations (
577481 fun : & crate :: ast:: FunDef ,
@@ -680,6 +584,18 @@ fn generate_body_operations(
680584 // Find the parameter declaration to get its type and index
681585 if let Some ( ( param_idx, param_decl) ) = fun. param_decls . iter ( ) . enumerate ( ) . find ( |( _, p) | p. ident . name == ident. name ) {
682586 if let Some ( param_ty) = & param_decl. ty {
587+ // Check if we're assigning to a reference - if so, it must be unique
588+ if let TyKind :: Data ( data_ty) = & param_ty. ty {
589+ if let DataTyKind :: Ref ( ref_dty) = & data_ty. dty {
590+ if ref_dty. own != Ownership :: Uniq {
591+ panic ! (
592+ "Assignment to non-unique reference is not allowed. Expected unique reference, found {:?}" ,
593+ ref_dty. own
594+ ) ;
595+ }
596+ }
597+ }
598+
683599 // Generate the target parameter type (should be gm address space)
684600 let target_type = get_mlir_type_string_with_address_space ( param_ty, context) ;
685601
@@ -856,7 +772,7 @@ fn generate_function_with_body(fun: &crate::ast::FunDef, context: &Context) -> S
856772 // TODO: When HACC dialect is registered in the MLIR context, replace this with:
857773 // let hacc_attributes = create_hacc_attributes(context);
858774 // and use the attributes with MLIR operation builders instead of string generation
859- signature. push_str ( & generate_hacc_attributes_string ( context ) ) ;
775+ signature. push_str ( " attributes {hacc.entry, hacc.function_kind = #hacc.function_kind<DEVICE>}" ) ;
860776 }
861777
862778 signature. push_str ( " {\n " ) ;
0 commit comments