diff --git a/vortex-cuda/src/device_buffer.rs b/vortex-cuda/src/device_buffer.rs index 17bcd44f5d4..d5967d2b804 100644 --- a/vortex-cuda/src/device_buffer.rs +++ b/vortex-cuda/src/device_buffer.rs @@ -153,15 +153,14 @@ impl CudaBufferExt for BufferHandle { } fn cuda_device_ptr(&self) -> VortexResult { - let ptr = self + let alloc = self .as_device_opt() .ok_or_else(|| vortex_err!("Buffer is not on device"))? .as_any() .downcast_ref::() - .ok_or_else(|| vortex_err!("expected CudaDeviceBuffer"))? - .device_ptr; + .ok_or_else(|| vortex_err!("expected CudaDeviceBuffer"))?; - Ok(ptr) + Ok(alloc.device_ptr + alloc.offset as u64) } } @@ -335,3 +334,29 @@ impl DeviceBuffer for CudaDeviceBuffer { } } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use vortex_array::LEGACY_SESSION; + use vortex_array::buffer::BufferHandle; + + use crate::CudaBufferExt; + use crate::CudaDeviceBuffer; + use crate::CudaSession; + + #[crate::test] + fn test_device_ptr() { + let ctx = CudaSession::create_execution_ctx(&*LEGACY_SESSION).unwrap(); + let handle1 = BufferHandle::new_device(Arc::new(CudaDeviceBuffer::new( + ctx.device_alloc::(1024).unwrap(), + ))); + + let handle2 = handle1.slice_typed::(10..1024); + assert_eq!( + handle2.cuda_device_ptr().unwrap(), + handle1.cuda_device_ptr().unwrap() + 40 + ); + } +}