feat: add rust kernels library for loading kernels #421
Conversation
| pub fn detect_cuda_version() -> Option<String> { | ||
| cuda_version_from_smi().or_else(cuda_version_from_nvcc) | ||
| } | ||
|
|
||
| fn cuda_version_from_smi() -> Option<String> { | ||
| let output = Command::new("nvidia-smi").output().ok()?; | ||
| if !output.status.success() { | ||
| return None; | ||
| } | ||
| let stdout = String::from_utf8_lossy(&output.stdout); | ||
| let rest = stdout.split("CUDA Version:").nth(1)?; | ||
| Some(rest.split_whitespace().next()?.to_string()) | ||
| } | ||
|
|
||
| fn cuda_version_from_nvcc() -> Option<String> { | ||
| let output = Command::new("nvcc").arg("--version").output().ok()?; | ||
| let stdout = String::from_utf8_lossy(&output.stdout); | ||
| let after = stdout.split("release ").nth(1)?; | ||
| Some(after.split(',').next()?.trim().to_string()) | ||
| } |
There was a problem hiding this comment.
This may not be the same as the library that a framework is compiled against and dynamically loads. Also, nvidia-smi gives the driver library version, not the CUDA runtime version. We need to get it from cudart, e.g. see:
kernels/kernels/src/kernels/backends.py
Line 254 in 8ed7bb4
libloading seems to be the most widely used library for dlopen:
There was a problem hiding this comment.
good catch, I've updated to prefer querying via cudaRuntimeGetVersion from cudart in the latest changes. I've tested locally and am not running into any issues — however, I'm not 100% sure whether we need more logic to search for cudart (like ctypes.util.find_library("cudart") does) if it's not in the default location
kernels-rs/src/candle.rs
Outdated
| pub fn candle_device(self) -> Result<Device> { | ||
| match self { | ||
| BackendKind::Cpu => Ok(Device::Cpu), | ||
| #[cfg(feature = "candle-cuda")] | ||
| BackendKind::Cuda => Device::new_cuda(0).map_err(Into::into), | ||
| #[cfg(not(feature = "candle-cuda"))] | ||
| BackendKind::Cuda => Ok(Device::Cpu), | ||
| BackendKind::Xpu => Ok(Device::Cpu), | ||
| } | ||
| } |
There was a problem hiding this comment.
I think this can be TryFrom<BackendDevice> for Device. Not 100% sure if it works with the coherency rules, since it's a different mod in the same crate. But I think it should.
There was a problem hiding this comment.
this is much nicer, thanks for the suggestion! updated in latest
kernels-rs/src/candle.rs
Outdated
| pub fn candle_supported(self) -> Self { | ||
| match self { | ||
| #[cfg(feature = "candle-cuda")] | ||
| BackendKind::Cuda => BackendKind::Cuda, | ||
| #[cfg(not(feature = "candle-cuda"))] | ||
| BackendKind::Cuda => BackendKind::Cpu, | ||
| other => other, | ||
| } | ||
| } |
There was a problem hiding this comment.
The function name is not very descriptive, maybe to_candle_supported?
There was a problem hiding this comment.
sounds good to me, updated in latest
kernels-rs/src/candle.rs
Outdated
| #[allow(unreachable_patterns)] | ||
| _ => BackendKind::Cpu, |
There was a problem hiding this comment.
I think it would be better to explicitly enumerate the other variants here, so that we can rely on exhaustiveness checking when other variants get added?
Also it seems that as it is, if Candle returns a device type that we don't support, it would result in Cpu, which results in kernels that are not compatible with the device type?
There was a problem hiding this comment.
agreed, that's a much better approach. I've updated to enumerate the Devices in latest and throw errors if the device is not supported by the current tvm-ffi impl
kernels-rs/src/candle.rs
Outdated
| macro_rules! ptr { | ||
| ($v:expr) => { | ||
| Ok(unsafe { $v.as_ptr().add(offset) as *mut c_void }) | ||
| }; | ||
| } |
There was a problem hiding this comment.
removed and opted to add *_slice_data_ptr functions for both cpu and cuda to avoid the macros in both places.
kernels-rs/src/candle.rs
Outdated
| macro_rules! ptr { | ||
| ($slice:expr) => {{ | ||
| let view = $slice.slice(offset..); | ||
| let (device_ptr, _sync) = view.device_ptr(&stream); | ||
| Ok(device_ptr as *mut c_void) | ||
| }}; | ||
| } |
There was a problem hiding this comment.
I think rather than a macro, this could be a trait + impl? At least I think with a generic type it should work with one implementation for all cases?
There was a problem hiding this comment.
updated to remove the macros (comment above) and explored a trait but ended up settling on a generic function like
fn cuda_slice_data_ptr<T>(
slice: &cudarc::driver::CudaSlice<T>,
stream: &cudarc::driver::CudaStream,
offset: usize,
) -> Result<*mut c_void> {
this helped make the cpu and cuda functions follow a similar functional pattern
note: the cpu path uses a generic function like
fn cpu_slice_data_ptr<T>(slice: &[T], offset: usize) -> Result<*mut c_void> {
happy to explore another approach if you see any issues with this! thanks!
kernels-rs/src/candle.rs
Outdated
| // Tensors are passed to the kernel as DLPack pointers directly into | ||
| // candle's storage - no copies for contiguous tensors. | ||
| pub trait CallKernel { | ||
| fn call(&self, func_name: &str, args: &[&Tensor]) -> Result<()>; |
There was a problem hiding this comment.
What if there are non-tensor arguments, e.g. option bools, epsilon floats, etc.?
There was a problem hiding this comment.
good catch — I originally only tested with a kernel that expected tensors. updated to handle multiple types in the latest changes
kernels-rs/src/candle.rs
Outdated
| let device_type = match kind { | ||
| BackendKind::Cpu => tvm_ffi::DL_CPU, | ||
| BackendKind::Cuda => tvm_ffi::DL_CUDA, | ||
| BackendKind::Xpu => tvm_ffi::DL_ONEAPI, | ||
| }; |
There was a problem hiding this comment.
Seems like this could use a From implementation outside the function?
There was a problem hiding this comment.
good point, updated in latest
This PR adds a new client library for loading hf kernels in rust. This allows tvm-ffi-based kernels to be called from rust, and optionally integrates with candle for a better tensor UX.
Example usage with candle
repo with candle and non candle examples https://github.com/drbh/hf-kernels-rust