Skip to content

Commit 74b581c

Browse files
committed
feat: add type checks for MLIR codegen types
1 parent 9996482 commit 74b581c

11 files changed

Lines changed: 106 additions & 148 deletions

File tree

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
fn assign_to_shared_ref<n: nat, r: prv>(
2+
a: &r shrd gpu.global [i16; 16],
3+
b: &r shrd gpu.global [i16; 16]
4+
) -[grid: gpu.grid<X<1>, X<16>>]-> () {
5+
b = a; // This should fail - cannot assign to shared reference
6+
()
7+
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Vector addition kernel demonstrating Descend's safe GPU programming model
2+
// This function showcases extended borrow checking, memory safety, and execution context tracking
3+
4+
// Generic function with type parameters:
5+
// - n: nat - Natural number parameter (for array size, though not used in this specific function)
6+
// - r: prv - Provenance parameter tracking memory region/lifetime for all references
7+
fn add<n: nat, r: prv>(
8+
// Shared reference to first input vector - multiple threads can read simultaneously
9+
// Memory space: gpu.global (GPU global memory)
10+
// Ownership: shrd (shared) - prevents write-after-read data races
11+
// Type: 16-element array of 16-bit signed integers
12+
a: &r shrd gpu.global [i16; 16],
13+
14+
// Shared reference to second input vector - multiple threads can read simultaneously
15+
// Same memory space and ownership constraints as 'a'
16+
b: &r shrd gpu.global [i16; 16],
17+
18+
// ERROR: This parameter should be declared as 'unq' (unique) instead of 'shrd' (shared)
19+
// because it's used for assignment (c = a + b). Shared references are read-only and
20+
// prevent data races by allowing multiple concurrent readers. Unique references are
21+
// required for write operations to ensure exclusive access and prevent race conditions.
22+
// The Descend compiler will detect this ownership violation and fail compilation.
23+
c: &r shrd gpu.global [i16; 16]
24+
25+
// Execution context specification - defines how this function runs on GPU hardware
26+
// - grid: gpu.grid<X<1>, X<16>> - GPU execution grid with 1 block containing 16 threads
27+
// - The type system ensures GPU memory is only accessed in GPU execution contexts
28+
// - Prevents invalid cross-device memory accesses (CPU accessing GPU memory)
29+
) -[grid: gpu.grid<X<1>, X<16>>]-> () {
30+
31+
// Vector addition operation - element-wise addition of arrays
32+
// The compiler generates safe parallel code that:
33+
// 1. Loads data from global memory to local memory for each thread
34+
// 2. Performs vectorized addition using HIVM dialect operations
35+
// 3. Stores results back to global memory safely
36+
// The ownership system ensures this operation is race-free
37+
//
38+
// LAZY LOADING: Descend's compiler implements lazy loading strategies:
39+
// - Memory loads are deferred until actually needed by computation
40+
// - The HIVM dialect generates 'hivm.hir.load' operations that load from
41+
// global memory (gm) to local memory (ub) only when data is accessed
42+
// - This minimizes memory bandwidth usage and improves cache efficiency
43+
// - The type system ensures loads happen in the correct execution context
44+
// - Shared references enable read-only access without unnecessary copies
45+
c = a + b;
46+
47+
// Unit return value - indicates successful completion
48+
// In MLIR, this becomes a 'return' operation
49+
()
50+
}

src/codegen/mlir/builder/control_flow.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ where
100100
let (then_region, true_value) = build_branch_region(case_true, ctx, build_expr)?;
101101

102102
// Build the else region
103-
let (else_region, false_value) = build_branch_region(case_false, ctx, build_expr)?;
103+
let (else_region, _false_value) = build_branch_region(case_false, ctx, build_expr)?;
104104

105105
// Determine result types based on whether branches produce values
106106
let result_types: Vec<Type> = if let Some(val) = true_value {

src/codegen/mlir/builder/place.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use super::super::error::MlirError;
44
use super::context::MlirContext;
55
use super::expr::build_expr;
66
use crate::ast as desc;
7+
use crate::ast::{DataTyKind, Ownership};
78

89
/// Build a place expression (variable lookup)
910
pub fn build_place_expr<'ctx, 'a, 'b>(
@@ -57,6 +58,20 @@ where
5758
{
5859
use desc::PlaceExprKind;
5960

61+
// Check if we're assigning to a reference - if so, it must be unique
62+
if let Some(ty) = &place_expr.ty {
63+
if let desc::TyKind::Data(data_ty) = &ty.ty {
64+
if let DataTyKind::Ref(ref_dty) = &data_ty.dty {
65+
if ref_dty.own != Ownership::Uniq {
66+
return Err(MlirError::General(format!(
67+
"Assignment to non-unique reference is not allowed. Expected unique reference, found {:?}",
68+
ref_dty.own
69+
)));
70+
}
71+
}
72+
}
73+
}
74+
6075
// Evaluate the right-hand side value
6176
let value = build_expr(value_expr, ctx)?
6277
.ok_or_else(|| MlirError::General("Missing value for assignment".to_string()))?;

src/codegen/mlir/mod.rs

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -72,22 +72,6 @@ fn build_module_internal(comp_unit: &CompilUnit) -> Result<String, MlirError> {
7272
Ok(builder.module().as_operation().to_string())
7373
}
7474

75-
pub fn gen(comp_unit: &CompilUnit, _idx_checks: bool) -> String {
76-
// Check if we need HIVM address spaces
77-
if needs_hivm_address_space(comp_unit) {
78-
to_mlir::types::generate_mlir_string_with_hivm(comp_unit)
79-
} else {
80-
// Use internal helper, but handle errors by falling back to string generation
81-
match build_module_internal(comp_unit) {
82-
Ok(mlir_string) => mlir_string,
83-
Err(_) => {
84-
// Fallback to string generation if internal building fails
85-
to_mlir::types::generate_mlir_string_with_hivm(comp_unit)
86-
}
87-
}
88-
}
89-
}
90-
9175
pub fn gen_checked(comp_unit: &CompilUnit, _idx_checks: bool) -> Result<String, String> {
9276
// Check if we need HIVM address spaces
9377
if needs_hivm_address_space(comp_unit) {

src/codegen/mlir/to_mlir/types.rs

Lines changed: 14 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::ast::{AtomicTy, BaseExec, DataTy, DataTyKind, FunDef, Memory, Nat, NatCtx, ScalarTy, Ty, TyKind};
1+
use crate::ast::{AtomicTy, BaseExec, DataTy, DataTyKind, FunDef, Memory, Nat, NatCtx, Ownership, ScalarTy, Ty, TyKind};
22
use melior::{
33
dialect::func,
44
ir::{
@@ -23,35 +23,6 @@ fn nat_to_dimension(nat: &Nat) -> String {
2323
}
2424
}
2525

26-
/// Helper function to create HACC attributes for GPU functions
27-
/// This function creates proper MLIR attribute objects for HACC attributes.
28-
/// Note: This requires the HACC dialect to be registered in the MLIR context.
29-
fn create_hacc_attributes<'c>(context: &'c Context) -> Vec<(Identifier<'c>, Attribute<'c>)> {
30-
// Create HACC entry attribute (hacc.entry) - this is a unit attribute
31-
let entry_attr = Attribute::parse(context, "unit")
32-
.expect("Failed to create HACC entry attribute");
33-
34-
// Create HACC function type attribute (hacc.function_kind = DEVICE)
35-
// Note: This will fail if the HACC dialect is not registered in the context
36-
let func_type_attr = Attribute::parse(context, "#hacc.function_kind<DEVICE>")
37-
.expect("Failed to create HACC function type attribute - ensure HACC dialect is registered");
38-
39-
vec![
40-
(Identifier::new(context, "hacc.entry"), entry_attr),
41-
(Identifier::new(context, "hacc.function_kind"), func_type_attr),
42-
]
43-
}
44-
45-
/// Helper function to generate HACC attributes string for GPU functions
46-
/// This function generates the MLIR string representation of HACC attributes
47-
/// for use in string-based MLIR generation. The proper attribute objects
48-
/// are available through create_hacc_attributes() when the HACC dialect is registered.
49-
fn generate_hacc_attributes_string(context: &Context) -> String {
50-
// Generate the string representation directly since we know the exact format
51-
// The create_hacc_attributes() function is available for when the HACC dialect is registered
52-
" attributes {hacc.entry, hacc.function_kind = #hacc.function_kind<DEVICE>}".to_string()
53-
}
54-
5526
/// Helper function to convert ScalarTy to MLIR Type
5627
fn scalar_ty_to_mlir<'c>(scalar_ty: &ScalarTy, context: &'c Context) -> Type<'c> {
5728
match scalar_ty {
@@ -505,73 +476,6 @@ fn collect_parameter_usage(fun: &crate::ast::FunDef) -> std::collections::HashMa
505476
param_usage
506477
}
507478

508-
/// Collect which parameters are referenced in the function body (legacy function for compatibility)
509-
fn collect_used_parameters(fun: &crate::ast::FunDef) -> std::collections::HashSet<String> {
510-
use crate::ast::{Expr, ExprKind, PlaceExprKind};
511-
use std::collections::HashSet;
512-
513-
let mut used_params = HashSet::new();
514-
515-
fn walk_expr(expr: &Expr, used_params: &mut HashSet<String>) {
516-
match &expr.expr {
517-
ExprKind::PlaceExpr(place_expr) => {
518-
if let PlaceExprKind::Ident(ident) = &place_expr.pl_expr {
519-
used_params.insert(ident.name.to_string());
520-
}
521-
}
522-
ExprKind::BinOp(_, lhs, rhs) => {
523-
walk_expr(lhs, used_params);
524-
walk_expr(rhs, used_params);
525-
}
526-
ExprKind::Let(_, _, value_expr) => {
527-
walk_expr(value_expr, used_params);
528-
}
529-
ExprKind::Seq(exprs) => {
530-
for expr in exprs {
531-
walk_expr(expr, used_params);
532-
}
533-
}
534-
ExprKind::Assign(place_expr, value_expr) => {
535-
if let PlaceExprKind::Ident(ident) = &place_expr.pl_expr {
536-
used_params.insert(ident.name.to_string());
537-
}
538-
walk_expr(value_expr, used_params);
539-
}
540-
ExprKind::App(_, _, args) => {
541-
for arg in args {
542-
walk_expr(arg, used_params);
543-
}
544-
}
545-
ExprKind::IfElse(cond, case_true, case_false) => {
546-
walk_expr(cond, used_params);
547-
walk_expr(case_true, used_params);
548-
walk_expr(case_false, used_params);
549-
}
550-
ExprKind::If(cond, case_true) => {
551-
walk_expr(cond, used_params);
552-
walk_expr(case_true, used_params);
553-
}
554-
ExprKind::ForNat(_, _, body) => {
555-
walk_expr(body, used_params);
556-
}
557-
ExprKind::Ref(_, _, place_expr) => {
558-
if let PlaceExprKind::Ident(ident) = &place_expr.pl_expr {
559-
used_params.insert(ident.name.to_string());
560-
}
561-
}
562-
ExprKind::Unsafe(expr) => {
563-
walk_expr(expr, used_params);
564-
}
565-
_ => {
566-
// Other expression types don't contain variable references
567-
}
568-
}
569-
}
570-
571-
walk_expr(&fun.body.body, &mut used_params);
572-
used_params
573-
}
574-
575479
/// Generate body operations for GPU functions
576480
fn generate_body_operations(
577481
fun: &crate::ast::FunDef,
@@ -680,6 +584,18 @@ fn generate_body_operations(
680584
// Find the parameter declaration to get its type and index
681585
if let Some((param_idx, param_decl)) = fun.param_decls.iter().enumerate().find(|(_, p)| p.ident.name == ident.name) {
682586
if let Some(param_ty) = &param_decl.ty {
587+
// Check if we're assigning to a reference - if so, it must be unique
588+
if let TyKind::Data(data_ty) = &param_ty.ty {
589+
if let DataTyKind::Ref(ref_dty) = &data_ty.dty {
590+
if ref_dty.own != Ownership::Uniq {
591+
panic!(
592+
"Assignment to non-unique reference is not allowed. Expected unique reference, found {:?}",
593+
ref_dty.own
594+
);
595+
}
596+
}
597+
}
598+
683599
// Generate the target parameter type (should be gm address space)
684600
let target_type = get_mlir_type_string_with_address_space(param_ty, context);
685601

@@ -856,7 +772,7 @@ fn generate_function_with_body(fun: &crate::ast::FunDef, context: &Context) -> S
856772
// TODO: When HACC dialect is registered in the MLIR context, replace this with:
857773
// let hacc_attributes = create_hacc_attributes(context);
858774
// and use the attributes with MLIR operation builders instead of string generation
859-
signature.push_str(&generate_hacc_attributes_string(context));
775+
signature.push_str(" attributes {hacc.entry, hacc.function_kind = #hacc.function_kind<DEVICE>}");
860776
}
861777

862778
signature.push_str(" {\n");

src/lib.rs

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,4 @@ pub fn compile(file_path: &str, backend: Backend) -> Result<(String, String), Er
2929
let ast_string = format!("{:#?}", compil_unit.items);
3030

3131
Ok((code_string, ast_string))
32-
}
33-
34-
pub fn compile_unchecked(file_path: &str, backend: Backend) -> Result<String, ErrorReported> {
35-
let source = parser::SourceCode::from_file(file_path)?;
36-
let mut compil_unit = parser::parse(&source)?;
37-
38-
ty_check::ty_check(&mut compil_unit)?;
39-
40-
let code_string = match backend {
41-
Backend::Cuda => codegen::cuda::gen(&compil_unit, false),
42-
Backend::Mlir => codegen::mlir::gen(&compil_unit, false),
43-
};
44-
45-
Ok(code_string)
46-
}
32+
}

src/main.rs

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,6 @@ struct Args {
2121
/// Print Ast
2222
#[arg(short, long)]
2323
print_ast: bool,
24-
25-
/// Skip MLIR verification checks
26-
#[arg(long)]
27-
no_checks: bool,
2824
}
2925

3026
/// Backend selection passed via CLI
@@ -52,23 +48,13 @@ fn main() {
5248
let output_dir = &args.output_dir;
5349

5450
// Compile using Descend
55-
let (code_string, ast_string) = if args.no_checks {
56-
let code_string = match descend::compile_unchecked(&input_path.to_string_lossy(), backend) {
51+
let (code_string, ast_string) = match descend::compile(&input_path.to_string_lossy(), backend) {
5752
Ok(output) => output,
5853
Err(e) => {
5954
eprintln!("Compilation failed: {:?}", e);
6055
std::process::exit(1);
6156
}
62-
};
63-
(code_string, String::new())
64-
} else {
65-
match descend::compile(&input_path.to_string_lossy(), backend) {
66-
Ok(output) => output,
67-
Err(e) => {
68-
eprintln!("Compilation failed: {:?}", e);
69-
std::process::exit(1);
70-
}
71-
}
57+
7258
};
7359

7460
// Generate output file path with appropriate extension based on backend

tests/mlir.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
#[path = "mlir/core.rs"]
22
mod core;
33

4+
#[path = "mlir/error_examples.rs"]
5+
mod error_examples;
6+
47
const BACKEND: descend::Backend = descend::Backend::Mlir;

tests/mlir/core.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,2 @@
1-
type Res = Result<(), descend::error::ErrorReported>;
2-
31
// Automatically generate tests for all .desc files in examples/core/
42
descend_derive::generate_desc_tests!();

0 commit comments

Comments
 (0)