
feat: add rust kernels library for loading kernels #421

Open
drbh wants to merge 6 commits into main from add-kernels-rs

Conversation

Collaborator

@drbh drbh commented Mar 31, 2026

This PR adds a new client library for loading HF kernels in Rust. It allows TVM-FFI based kernels to be called from Rust and optionally integrates with Candle for a better tensor UX.

Example usage with candle

use candle_core::{Device, Tensor};
use kernels::Result;
use kernels::candle::CallKernel;

fn main() -> Result<()> {
    let activation = kernels::candle::get_kernel("drbh/relu-tvm", 1)?;
    let device = activation.device()?;
    println!("Backend: {}", activation.backend());

    let x = Tensor::new(&[-1.0f32, 2.0, -3.0, 4.0, -0.5, 0.0, 1.5, -2.5], &device)?;
    let y = Tensor::zeros_like(&x)?;
    activation.call("relu", &[&y, &x])?;

    let result = y.to_vec1::<f32>()?;
    let expected = Tensor::new(&*x.to_vec1::<f32>()?, &Device::Cpu)?
        .relu()?
        .to_vec1::<f32>()?;

    println!("Input:    {:?}", x.to_vec1::<f32>()?);
    println!("TVM FFI:  {result:?}");
    println!("Candle:   {expected:?}");
    assert_eq!(result, expected);
    println!("OK");
    Ok(())
}

Repo with Candle and non-Candle examples: https://github.com/drbh/hf-kernels-rust

@drbh drbh marked this pull request as ready for review April 1, 2026 15:01
Comment on lines +74 to +93
pub fn detect_cuda_version() -> Option<String> {
    cuda_version_from_smi().or_else(cuda_version_from_nvcc)
}

fn cuda_version_from_smi() -> Option<String> {
    let output = Command::new("nvidia-smi").output().ok()?;
    if !output.status.success() {
        return None;
    }
    let stdout = String::from_utf8_lossy(&output.stdout);
    let rest = stdout.split("CUDA Version:").nth(1)?;
    Some(rest.split_whitespace().next()?.to_string())
}

fn cuda_version_from_nvcc() -> Option<String> {
    let output = Command::new("nvcc").arg("--version").output().ok()?;
    let stdout = String::from_utf8_lossy(&output.stdout);
    let after = stdout.split("release ").nth(1)?;
    Some(after.split(',').next()?.trim().to_string())
}
Member

This may not be the same as the library that a framework is compiled against and dynamically loads. Also, nvidia-smi gives the driver library version, not the CUDA runtime version. We need to get it from cudart, e.g. see:

def _get_cuda() -> Optional[CUDA]:

libloading seems to be the most widely used library for dlopen:

https://github.com/nagisa/rust_libloading/

Collaborator Author

good catch, I've updated to prefer querying via cudaRuntimeGetVersion from cudart in the latest changes. I've tested locally and am not running into any issues; however, I'm not 100% sure whether we need more logic to search for cudart (the way ctypes.util.find_library("cudart") does) if it's not in the default location

Comment on lines +15 to +24
pub fn candle_device(self) -> Result<Device> {
    match self {
        BackendKind::Cpu => Ok(Device::Cpu),
        #[cfg(feature = "candle-cuda")]
        BackendKind::Cuda => Device::new_cuda(0).map_err(Into::into),
        #[cfg(not(feature = "candle-cuda"))]
        BackendKind::Cuda => Ok(Device::Cpu),
        BackendKind::Xpu => Ok(Device::Cpu),
    }
}
Member

I think this can be TryFrom<BackendDevice> for Device. Not 100% sure if it works with the coherency rules, since it's a different mod in the same crate. But I think it should.

Collaborator Author

this is much nicer, thanks for the suggestion! updated in latest

Comment on lines +26 to +34
pub fn candle_supported(self) -> Self {
    match self {
        #[cfg(feature = "candle-cuda")]
        BackendKind::Cuda => BackendKind::Cuda,
        #[cfg(not(feature = "candle-cuda"))]
        BackendKind::Cuda => BackendKind::Cpu,
        other => other,
    }
}
Member

The function name is not very descriptive, maybe to_candle_supported?

Collaborator Author

sounds good to me, updated in latest

Comment on lines +43 to +44
#[allow(unreachable_patterns)]
_ => BackendKind::Cpu,
Member

I think it would be better to explicitly enumerate the other variants here, so that we can rely on exhaustiveness checking when other variants get added?

Also it seems that as it is, if Candle returns a device type that we don't support, it would result in Cpu, which results in kernels that are not compatible with the device type?

Collaborator Author

agreed, that's a much better approach. I've updated to enumerate the Devices in latest and throw errors if the device is not supported by the current TVM-FFI impl

Comment on lines +75 to +79
macro_rules! ptr {
    ($v:expr) => {
        Ok(unsafe { $v.as_ptr().add(offset) as *mut c_void })
    };
}
Member

Remove, make explicit.

Collaborator Author

removed and opted to add *_slice_data_ptr functions for both CPU and CUDA to avoid the macros in both places.

Comment on lines +101 to +107
macro_rules! ptr {
    ($slice:expr) => {{
        let view = $slice.slice(offset..);
        let (device_ptr, _sync) = view.device_ptr(&stream);
        Ok(device_ptr as *mut c_void)
    }};
}
Member

I think rather than a macro, this could be a trait + impl? At least I think with a generic type it should work with one implementation for all cases?

Collaborator Author

updated to remove the macros (per the comment above) and explored a trait, but ended up settling on a generic function like

fn cuda_slice_data_ptr<T>(
    slice: &cudarc::driver::CudaSlice<T>,
    stream: &cudarc::driver::CudaStream,
    offset: usize,
) -> Result<*mut c_void> {

this helped make the CPU and CUDA functions follow a similar pattern

note: the cpu path uses a generic function like

fn cpu_slice_data_ptr<T>(slice: &[T], offset: usize) -> Result<*mut c_void> {

happy to explore another approach if you see any issues with this! thanks!

// Tensors are passed to the kernel as DLPack pointers directly into
// candle's storage - no copies for contiguous tensors.
pub trait CallKernel {
    fn call(&self, func_name: &str, args: &[&Tensor]) -> Result<()>;
Member

What if there are non-tensor arguments, e.g. option bools, epsilon floats, etc.?

Collaborator Author

good catch, I originally only tested with a kernel that expected tensors. Updated to handle multiple types in the latest changes

Comment on lines +183 to +187
let device_type = match kind {
    BackendKind::Cpu => tvm_ffi::DL_CPU,
    BackendKind::Cuda => tvm_ffi::DL_CUDA,
    BackendKind::Xpu => tvm_ffi::DL_ONEAPI,
};
Member

Seems like this could use a From implementation outside the function?

Collaborator Author

good point, updated in latest
