Skip to content

Commit 8b0fe01

Browse files
committed
feat: handle dynamic arrays
1 parent 42d312f commit 8b0fe01

3 files changed

Lines changed: 54 additions & 13 deletions

File tree

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
fn gpu_vec_add<r: prv, n: nat>(
2+
a_array: &r uniq gpu.global [i32; n],
3+
b_array: &r shrd gpu.global [i32; n]
4+
) -[grid: gpu.grid<X<64>, X<1024>>]-> () {
5+
sched block in grid {
6+
sched thread in block {
7+
let a_ref = &uniq (*a_array).to_view.grp::<1024>[[block]][[thread]];
8+
*a_ref = *a_ref + (*b_array).to_view.grp::<1024>[[block]][[thread]]
9+
}
10+
}
11+
}
12+
13+
fn main() -[t: cpu.thread]-> () <'a>{
14+
let a_array: &'a uniq cpu.mem [i32; 64*1024] = unsafe _;
15+
let b_array: &'a shrd cpu.mem [i32; 64*1024] = unsafe _;
16+
let mut gpu = gpu_device(0);
17+
18+
let mut a_gpu_array = gpu_alloc_copy(&uniq gpu, &shrd *a_array);
19+
let b_gpu_array = gpu_alloc_copy(&uniq gpu, &shrd *b_array);
20+
gpu_vec_add::<<<X<64>, X<1024>>>>(&uniq a_gpu_array, &shrd b_gpu_array);
21+
copy_to_host(&shrd a_gpu_array, a_array);
22+
()
23+
}

src/codegen/mlir/to_mlir/types.rs

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,12 @@ pub trait ToMlir {
1717
}
1818

1919
impl Nat {
20-
fn to_dimension_i64(self: &Self) -> i64 {
20+
fn to_dimension(self: &Self) -> Vec<i64> {
2121
// Try to evaluate the Nat with an empty context
2222
let nat_ctx = NatCtx::new();
2323
match self.eval(&nat_ctx) {
24-
Ok(size) => size as i64,
25-
Err(_) => panic!(
26-
"Array dimensions must be compile-time known. Dynamic arrays are not supported."
27-
),
24+
Ok(size) => vec![size as i64],
25+
Err(_) => vec![],
2826
}
2927
}
3028
}
@@ -149,8 +147,8 @@ fn ref_array_to_mlir<'c>(
149147
) -> Type<'c> {
150148
// Array reference -> memref with dimensions
151149
let elem_type = elem_ty.to_mlir(context);
152-
let dim = size.to_dimension_i64();
153-
let base_type: Type<'c> = MemRefType::new(elem_type, &[dim], None, None).into();
150+
let dim = size.to_dimension();
151+
let base_type: Type<'c> = MemRefType::new(elem_type, &dim, None, None).into();
154152

155153
// Add HIVM address space if needed
156154
let base_str = base_type.to_string();
@@ -169,8 +167,8 @@ fn ref_at_to_mlir<'c>(inner: &DataTy, mem: &Memory, context: &'c Context) -> Typ
169167
}
170168
DataTyKind::Array(elem_ty, size) | DataTyKind::ArrayShape(elem_ty, size) => {
171169
let elem_type = elem_ty.to_mlir(context);
172-
let dim = size.to_dimension_i64();
173-
MemRefType::new(elem_type, &[dim], None, None).into()
170+
let dim = size.to_dimension();
171+
MemRefType::new(elem_type, &dim, None, None).into()
174172
}
175173
DataTyKind::Tuple(_) => {
176174
unimplemented!("Tuple references with At not yet supported in MLIR conversion")
@@ -239,13 +237,13 @@ impl ToMlir for DataTy {
239237
DataTyKind::Ident(ident) => ident.to_mlir(context),
240238
DataTyKind::Array(elem_ty, size) => {
241239
let elem_type = elem_ty.to_mlir(context);
242-
let dim = size.to_dimension_i64();
243-
MemRefType::new(elem_type, &[dim], None, None).into()
240+
let dim = size.to_dimension();
241+
MemRefType::new(elem_type, &dim, None, None).into()
244242
}
245243
DataTyKind::ArrayShape(elem_ty, size) => {
246244
let elem_type = elem_ty.to_mlir(context);
247-
let dim = size.to_dimension_i64();
248-
MemRefType::new(elem_type, &[dim], None, None).into()
245+
let dim = size.to_dimension();
246+
MemRefType::new(elem_type, &dim, None, None).into()
249247
}
250248
DataTyKind::Struct(_) => {
251249
unimplemented!("Struct types not yet supported in MLIR conversion")
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
---
2+
source: tests/mlir/core.rs
3+
expression: output
4+
---
5+
module {
6+
func.func @add(%arg0: memref<16xi16, #hivm.address_space<gm>>, %arg1: memref<16xi16, #hivm.address_space<gm>>, %arg2: memref<16xi16, #hivm.address_space<gm>>, %arg3: memref<16xi16, #hivm.address_space<gm>>) attributes {hacc.entry, hacc.function_kind = #hacc.function_kind<DEVICE>} {
7+
%alloc = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
8+
hivm.hir.load ins(%arg0 : memref<16xi16, #hivm.address_space<gm>>) outs(%alloc : memref<16xi16, #hivm.address_space<ub>>)
9+
%alloc_0 = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
10+
hivm.hir.load ins(%arg1 : memref<16xi16, #hivm.address_space<gm>>) outs(%alloc_0 : memref<16xi16, #hivm.address_space<ub>>)
11+
%alloc_1 = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
12+
hivm.hir.load ins(%arg2 : memref<16xi16, #hivm.address_space<gm>>) outs(%alloc_1 : memref<16xi16, #hivm.address_space<ub>>)
13+
%alloc_2 = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
14+
hivm.hir.vadd ins(%alloc, %alloc_0 : memref<16xi16, #hivm.address_space<ub>>, memref<16xi16, #hivm.address_space<ub>>) outs(%alloc_2 : memref<16xi16, #hivm.address_space<ub>>)
15+
%alloc_3 = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
16+
hivm.hir.vadd ins(%alloc_2, %alloc_1 : memref<16xi16, #hivm.address_space<ub>>, memref<16xi16, #hivm.address_space<ub>>) outs(%alloc_3 : memref<16xi16, #hivm.address_space<ub>>)
17+
hivm.hir.store ins(%alloc_3 : memref<16xi16, #hivm.address_space<ub>>) outs(%arg3 : memref<16xi16, #hivm.address_space<gm>>)
18+
return
19+
}
20+
}

0 commit comments

Comments
 (0)