Skip to content

Commit 48f5bd9

Browse files
committed
feat: add type checks for MLIR codegen types
1 parent 9996482 commit 48f5bd9

6 files changed

Lines changed: 113 additions & 1 deletion

File tree

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
fn assign_to_shared_ref<n: nat, r: prv>(
2+
a: &r shrd gpu.global [i16; 16],
3+
b: &r shrd gpu.global [i16; 16]
4+
) -[grid: gpu.grid<X<1>, X<16>>]-> () {
5+
b = a; // This should fail - cannot assign to shared reference
6+
()
7+
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Vector addition kernel demonstrating Descend's safe GPU programming model
2+
// This function showcases extended borrow checking, memory safety, and execution context tracking
3+
4+
// Generic function with type parameters:
5+
// - n: nat - Natural number parameter (for array size, though not used in this specific function)
6+
// - r: prv - Provenance parameter tracking memory region/lifetime for all references
7+
fn add<n: nat, r: prv>(
8+
// Shared reference to first input vector - multiple threads can read simultaneously
9+
// Memory space: gpu.global (GPU global memory)
10+
// Ownership: shrd (shared) - prevents write-after-read data races
11+
// Type: 16-element array of 16-bit signed integers
12+
a: &r shrd gpu.global [i16; 16],
13+
14+
// Shared reference to second input vector - multiple threads can read simultaneously
15+
// Same memory space and ownership constraints as 'a'
16+
b: &r shrd gpu.global [i16; 16],
17+
18+
// ERROR: This parameter should be declared as 'unq' (unique) instead of 'shrd' (shared)
19+
// because it's used for assignment (c = a + b). Shared references are read-only and
20+
// prevent data races by allowing multiple concurrent readers. Unique references are
21+
// required for write operations to ensure exclusive access and prevent race conditions.
22+
// The Descend compiler will detect this ownership violation and fail compilation.
23+
c: &r shrd gpu.global [i16; 16]
24+
25+
// Execution context specification - defines how this function runs on GPU hardware
26+
// - grid: gpu.grid<X<1>, X<16>> - GPU execution grid with 1 block containing 16 threads
27+
// - The type system ensures GPU memory is only accessed in GPU execution contexts
28+
// - Prevents invalid cross-device memory accesses (CPU accessing GPU memory)
29+
) -[grid: gpu.grid<X<1>, X<16>>]-> () {
30+
31+
// Vector addition operation - element-wise addition of arrays
32+
// The compiler generates safe parallel code that:
33+
// 1. Loads data from global memory to local memory for each thread
34+
// 2. Performs vectorized addition using HIVM dialect operations
35+
// 3. Stores results back to global memory safely
36+
// The ownership system ensures this operation is race-free
37+
//
38+
// LAZY LOADING: Descend's compiler implements lazy loading strategies:
39+
// - Memory loads are deferred until actually needed by computation
40+
// - The HIVM dialect generates 'hivm.hir.load' operations that load from
41+
// global memory (gm) to local memory (ub) only when data is accessed
42+
// - This minimizes memory bandwidth usage and improves cache efficiency
43+
// - The type system ensures loads happen in the correct execution context
44+
// - Shared references enable read-only access without unnecessary copies
45+
c = a + b;
46+
47+
// Unit return value - indicates successful completion
48+
// In MLIR, this becomes a 'return' operation
49+
()
50+
}

src/codegen/mlir/builder/place.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,19 @@ use super::super::error::MlirError;
44
use super::context::MlirContext;
55
use super::expr::build_expr;
66
use crate::ast as desc;
7+
use crate::ast::{DataTyKind, Ownership};
8+
9+
/// Check if a PlaceExpr represents a unique reference
10+
fn is_unique_reference(place_expr: &desc::PlaceExpr) -> bool {
11+
if let Some(ty) = &place_expr.ty {
12+
if let desc::TyKind::Data(data_ty) = &ty.ty {
13+
if let DataTyKind::Ref(ref_dty) = &data_ty.dty {
14+
return ref_dty.own == Ownership::Uniq;
15+
}
16+
}
17+
}
18+
false
19+
}
720

821
/// Build a place expression (variable lookup)
922
pub fn build_place_expr<'ctx, 'a, 'b>(
@@ -57,6 +70,20 @@ where
5770
{
5871
use desc::PlaceExprKind;
5972

73+
// Check if we're assigning to a reference - if so, it must be unique
74+
if let Some(ty) = &place_expr.ty {
75+
if let desc::TyKind::Data(data_ty) = &ty.ty {
76+
if let DataTyKind::Ref(ref_dty) = &data_ty.dty {
77+
if ref_dty.own != Ownership::Uniq {
78+
return Err(MlirError::General(format!(
79+
"Assignment to non-unique reference is not allowed. Expected unique reference, found {:?}",
80+
ref_dty.own
81+
)));
82+
}
83+
}
84+
}
85+
}
86+
6087
// Evaluate the right-hand side value
6188
let value = build_expr(value_expr, ctx)?
6289
.ok_or_else(|| MlirError::General("Missing value for assignment".to_string()))?;

src/codegen/mlir/to_mlir/types.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::ast::{AtomicTy, BaseExec, DataTy, DataTyKind, FunDef, Memory, Nat, NatCtx, ScalarTy, Ty, TyKind};
1+
use crate::ast::{AtomicTy, BaseExec, DataTy, DataTyKind, FunDef, Memory, Nat, NatCtx, Ownership, ScalarTy, Ty, TyKind};
22
use melior::{
33
dialect::func,
44
ir::{
@@ -680,6 +680,18 @@ fn generate_body_operations(
680680
// Find the parameter declaration to get its type and index
681681
if let Some((param_idx, param_decl)) = fun.param_decls.iter().enumerate().find(|(_, p)| p.ident.name == ident.name) {
682682
if let Some(param_ty) = &param_decl.ty {
683+
// Check if we're assigning to a reference - if so, it must be unique
684+
if let TyKind::Data(data_ty) = &param_ty.ty {
685+
if let DataTyKind::Ref(ref_dty) = &data_ty.dty {
686+
if ref_dty.own != Ownership::Uniq {
687+
panic!(
688+
"Assignment to non-unique reference is not allowed. Expected unique reference, found {:?}",
689+
ref_dty.own
690+
);
691+
}
692+
}
693+
}
694+
683695
// Generate the target parameter type (should be gm address space)
684696
let target_type = get_mlir_type_string_with_address_space(param_ty, context);
685697

tests/mlir.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
#[path = "mlir/core.rs"]
22
mod core;
33

4+
#[path = "mlir/error_examples.rs"]
5+
mod error_examples;
6+
47
const BACKEND: descend::Backend = descend::Backend::Mlir;

tests/mlir/error_examples.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
use super::BACKEND;
2+
3+
#[test]
4+
#[should_panic]
5+
fn vec_add_memory_issue_error() {
6+
descend::compile("examples/error-examples/vec_add_memory_issue.desc", BACKEND).unwrap();
7+
}
8+
9+
#[test]
10+
#[should_panic]
11+
fn assign_to_shared_ref_error() {
12+
descend::compile("examples/error-examples/assign_to_shared_ref.desc", BACKEND).unwrap();
13+
}

0 commit comments

Comments
 (0)