Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ jobs:

- name: Install cargo-gpu
run: |
cargo install --git https://github.com/Rust-GPU/cargo-gpu cargo-gpu
cargo install cargo-gpu --version 0.10.0-alpha.1
cargo gpu install --auto-install-rust-toolchain

- name: Run clippy lints
Expand Down Expand Up @@ -80,7 +80,7 @@ jobs:

- name: Install cargo-gpu
run: |
cargo install --git https://github.com/Rust-GPU/cargo-gpu cargo-gpu
cargo install cargo-gpu --version 0.10.0-alpha.1
cargo gpu install --auto-install-rust-toolchain

- name: Check documentation
Expand All @@ -106,7 +106,7 @@ jobs:
sweep-cache: true
- name: Install cargo-gpu
run: |
cargo install --git https://github.com/Rust-GPU/cargo-gpu cargo-gpu
cargo install cargo-gpu --version 0.10.0-alpha.1
cargo gpu install --auto-install-rust-toolchain
- name: Run Cargo Tests
run: |
Expand Down
15 changes: 15 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Changelog

_Disclaimer: this changelog is updated using generative AI, but is still verified manually._

## v0.1.0

This shows the changes between the time of open-sourcing the crate and its first release to crates.io:

### Added
- `println!` support for shaders running on the CPU backend (`khal-std`).

### Changed
- Switch `spirv-std` and `spirv-std-macros` to the published `0.10.0-alpha.1` release (previously pinned to a git revision).
- Cache coroutines on the CPU backend for improved performance.
- Enable incremental builds in the workspace to work around a `rust-gpu` issue where the example shader entrypoint was being dropped.
15 changes: 7 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
on any platform: **WebGPU**, **CUDA**, or **CPU** -- from a single codebase.

> **Warning**
> KHAL is still under heavy development.
> KHAL is still under heavy development. The CUDA backend is currently only supported when using the
> github version of `khal-std` (because some dependencies are not available on cartes.io yet). If you
> don’t intend to target cuda, then the published version of `khal-std` is the way to go.

<p align="center">
<img src="./assets/khal-diagram.png" height="400px">
Expand All @@ -25,26 +27,23 @@ on any platform: **WebGPU**, **CUDA**, or **CPU** -- from a single codebase.
- **Proc-macro bindings** -- `#[spirv_bindgen]` generates type-safe host-side structs from your shader function signature.
- **Build pipeline** -- `khal-builder` orchestrates `cargo gpu` and `cargo cuda` to compile shaders at build time.


## Development setup

### cargo-gpu (required for SPIR-V / WebGPU)

The crates.io version of `cargo-gpu` is outdated. Install from Git and let it set up its Rust
toolchain:
Install `cargo-gpu` from crates.io:

```bash
cargo install --git https://github.com/Rust-GPU/cargo-gpu cargo-gpu
cargo install cargo-gpu --version 0.10.0-alpha.1
cargo gpu install
```

### cargo-cuda (required for CUDA / PTX)

`cargo-cuda` lives in this repository (`crates/cargo-cuda`). Install it from the workspace and
build the `rustc_codegen_nvvm` codegen backend:
Install `cargo-cuda` from crates.io:

```bash
cargo install --path https://github.com/dimforge/khal cargo-cuda
cargo install cargo-cuda --version 0.1.0
cargo cuda install
```

Expand Down
5 changes: 3 additions & 2 deletions crates/khal-example-shaders/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![cfg_attr(target_arch = "spirv", no_std)]

use khal_std::glamx::UVec3;
use khal_std::index::MaybeIndexUnchecked;
use khal_std::macros::{spirv, spirv_bindgen};

#[spirv_bindgen]
Expand All @@ -11,7 +12,7 @@ pub fn add_assign(
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] b: &[f32],
) {
let thread_id = invocation_id.x as usize;
if thread_id < a.len() {
a[thread_id] += b[thread_id];
if thread_id < a.len() && thread_id < b.len() {
*a.at_mut(thread_id) += b.read(thread_id);
}
}
Binary file modified crates/khal-example/shaders-spirv/add_assign.spv
Binary file not shown.
4 changes: 2 additions & 2 deletions crates/khal-std/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ unsafe_remove_boundchecks = []
glamx = { version = "0.2", default-features = false, features = ["nostd-libm", "bytemuck"] }
rayon = { version = "1", optional = true }
corosensei = { version = "0.3", optional = true }
spirv-std-macros = { git = "https://github.com/Rust-GPU/rust-gpu.git", rev = "6a67e7b5" }
spirv-std-macros = "0.10.0-alpha.1"
khal-derive = { path = "../khal-derive" }

[lints]
workspace = true

[target.'cfg(not(target_arch = "nvptx64"))'.dependencies]
spirv-std = { git = "https://github.com/Rust-GPU/rust-gpu.git", rev = "6a67e7b5" }
spirv-std = "0.10.0-alpha.1"

[target.'cfg(target_arch = "nvptx64")'.dependencies]
# Fixes the UVec3::element_product bug
Expand Down
99 changes: 73 additions & 26 deletions crates/khal-std/src/arch/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@
//! we simulate GPU threads as lightweight stackful coroutines (via `corosensei`)
//! that yield at each barrier. A single OS thread runs all coroutines
//! cooperatively, with zero OS scheduling overhead.
//!
//! Coroutine stacks are pooled per-thread to avoid repeated mmap/munmap
//! syscalls across dispatches.

extern crate std;

use std::cell::Cell;
use std::cell::RefCell;

// =============================================================================
// Barrier: yields the current coroutine back to the scheduler
Expand Down Expand Up @@ -63,13 +67,46 @@ pub fn dispatch_workgroups(num_workgroups: usize, f: impl Fn(u32) + Sync + Send)
}

// =============================================================================
// Intra-workgroup dispatch (using corosensei coroutines)
// Intra-workgroup dispatch (using corosensei coroutines with stack pooling)
// =============================================================================

/// Stack size for coroutines. Shader functions use very little stack space
/// (local variables and small arrays), so 64KB is more than sufficient.
const COROUTINE_STACK_SIZE: usize = 64 * 1024;

thread_local! {
/// Pointer to the active Yielder (null when not in coroutine mode).
/// Each coroutine sets this before calling the work function.
static COROUTINE_YIELDER: Cell<*mut corosensei::Yielder<(), ()>> = const { Cell::new(std::ptr::null_mut()) };

/// Pool of reusable coroutine stacks. Stacks are allocated on first use
/// and returned to the pool after each dispatch, avoiding repeated
/// mmap/munmap syscalls.
static STACK_POOL: RefCell<Vec<corosensei::stack::DefaultStack>> = RefCell::new(Vec::new());
}

/// Takes `count` stacks from the thread-local pool, allocating new ones if needed.
fn take_stacks(count: usize) -> Vec<corosensei::stack::DefaultStack> {
STACK_POOL.with(|pool| {
let mut pool = pool.borrow_mut();
let reusable = count.min(pool.len());
let drain_start = pool.len() - reusable;
let mut stacks: Vec<corosensei::stack::DefaultStack> = pool.drain(drain_start..).collect();
for _ in stacks.len()..count {
stacks.push(
corosensei::stack::DefaultStack::new(COROUTINE_STACK_SIZE)
.expect("failed to allocate coroutine stack"),
);
}
stacks
})
}

/// Returns stacks to the thread-local pool for reuse.
fn return_stacks(stacks: impl IntoIterator<Item = corosensei::stack::DefaultStack>) {
STACK_POOL.with(|pool| {
pool.borrow_mut().extend(stacks);
});
}

/// Dispatches `num_threads` virtual threads using cooperative coroutines.
Expand All @@ -79,6 +116,7 @@ thread_local! {
/// reached the barrier), the scheduler resumes them for the next phase.
///
/// This runs on a single OS thread with zero OS scheduling overhead.
/// Coroutine stacks are pooled to avoid repeated allocation.
pub fn dispatch_workgroup_threads(num_threads: usize, f: impl Fn(u32) + Sync) {
use corosensei::{Coroutine, CoroutineResult};

Expand All @@ -88,41 +126,50 @@ pub fn dispatch_workgroup_threads(num_threads: usize, f: impl Fn(u32) + Sync) {
let f_ref: &'static (dyn Fn(u32) + Sync) =
unsafe { core::mem::transmute(&f as &(dyn Fn(u32) + Sync)) };

// Create one coroutine per virtual thread.
let mut coroutines: Vec<Option<Coroutine<(), (), ()>>> = (0..num_threads)
.map(|tid| {
Some(Coroutine::new(move |yielder, ()| {
// Store the yielder pointer in TLS so barrier_wait() can find it.
COROUTINE_YIELDER.with(|cell| {
cell.set(yielder as *const _ as *mut _);
});
f_ref(tid as u32);
// Clear the yielder pointer.
COROUTINE_YIELDER.with(|cell| cell.set(std::ptr::null_mut()));
}))
})
.collect();
// Take stacks from the pool (reuses existing ones, allocates only if needed).
let stacks = take_stacks(num_threads);

// Create one coroutine per virtual thread, using pooled stacks.
let mut coroutines: Vec<Option<Coroutine<(), (), (), corosensei::stack::DefaultStack>>> =
stacks
.into_iter()
.enumerate()
.map(|(tid, stack)| {
Some(Coroutine::with_stack(stack, move |yielder, ()| {
// Store the yielder pointer in TLS so barrier_wait() can find it.
COROUTINE_YIELDER.with(|cell| {
cell.set(yielder as *const _ as *mut _);
});
f_ref(tid as u32);
// Clear the yielder pointer.
COROUTINE_YIELDER.with(|cell| cell.set(std::ptr::null_mut()));
}))
})
.collect();

// Run all coroutines in round-robin until all complete.
// Each "round" corresponds to one barrier synchronization point.
// Completed coroutines have their stacks recovered for pooling.
let mut recovered_stacks = Vec::with_capacity(num_threads);
loop {
let mut all_done = true;
for slot in coroutines.iter_mut() {
if let Some(coroutine) = slot {
match coroutine.resume(()) {
CoroutineResult::Yield(()) => {
// Coroutine yielded at a barrier — continue to next one.
all_done = false;
}
CoroutineResult::Return(()) => {
// Coroutine completed — remove it.
*slot = None;
}
for i in 0..coroutines.len() {
let result = coroutines[i].as_mut().map(|c| c.resume(()));
match result {
Some(CoroutineResult::Yield(())) => {
all_done = false;
}
Some(CoroutineResult::Return(())) => {
recovered_stacks.push(coroutines[i].take().unwrap().into_stack());
}
None => {}
}
}
if all_done {
break;
}
}

// Return stacks to the pool for reuse by future dispatches.
return_stacks(recovered_stacks);
}
9 changes: 9 additions & 0 deletions crates/khal-std/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,12 @@ pub use glamx;

#[cfg(target_arch = "nvptx64")]
pub use cuda_std;

#[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))]
pub use std::println;
#[cfg(any(target_arch = "spirv", target_arch = "nvptx64"))]
#[macro_export]
macro_rules! println {
() => {};
($($arg:tt)*) => {};
}
Loading