From 50851da257fa5c0ef6a91255d9cf457c4245e555 Mon Sep 17 00:00:00 2001 From: cjchanh Date: Mon, 13 Apr 2026 16:42:48 -0600 Subject: [PATCH 1/3] Add fused Metal q4 path for MLX 4-bit models --- cake-core/src/backends/metal/mod.rs | 1474 ++++++++++++++--- cake-core/src/backends/metal/ops.msl | 362 ++++ cake-core/src/backends/mod.rs | 64 +- cake-core/src/models/common/attention.rs | 235 ++- cake-core/src/models/common/mlp.rs | 113 +- cake-core/src/utils/gptq.rs | 281 +++- cake-core/src/utils/mlx_quant.rs | 226 +++ cake-core/src/utils/mod.rs | 145 +- cake-core/src/utils/quantized_linear.rs | 158 ++ .../tests/unit_tests/test_quantization.rs | 617 ++++++- 10 files changed, 3344 insertions(+), 331 deletions(-) create mode 100644 cake-core/src/utils/mlx_quant.rs create mode 100644 cake-core/src/utils/quantized_linear.rs diff --git a/cake-core/src/backends/metal/mod.rs b/cake-core/src/backends/metal/mod.rs index 64874f0..4feabdb 100644 --- a/cake-core/src/backends/metal/mod.rs +++ b/cake-core/src/backends/metal/mod.rs @@ -8,7 +8,9 @@ //! The `synchronize()` method flushes the command buffer and is called at strategic //! points during forward passes (see GatedDeltaNet, Qwen3_5FullAttention). -use candle_core::{backend::BackendStorage as _, CpuStorage, DType, Device, Layout, Result, Shape, Tensor}; +use candle_core::{ + CpuStorage, DType, Device, Layout, Result, Shape, Tensor, backend::BackendStorage as _, +}; use super::ComputeBackend; @@ -22,28 +24,50 @@ const FUSED_OPS_MSL: &str = include_str!("ops.msl"); /// All kernel names in the MSL source — compiled eagerly on first access. const ALL_KERNELS: &[&str] = &[ - "gelu_f32", "gelu_f16", - "sigmoid_f32", "sigmoid_f16", - "silu_f32", "silu_f16", - "stable_softplus_f32", "stable_softplus_f16", - "silu_mul_f32", "silu_mul_f16", - "add3_f32", "add3_f16", - "exp_mul_f32", "exp_mul_f16", - "sub_mul_f32", "sub_mul_f16", - "add_scaled_f32", "add_scaled_f16", - "depthwise_conv1d_silu_f32", "depthwise_conv1d_silu_f16", - "depthwise_conv1d_bias_f32", "depthwise_conv1d_bias_f16", - "rms_norm_f32", "rms_norm_f16", - "rms_norm_gated_f32", "rms_norm_gated_f16", - "add_rms_norm_f32", "add_rms_norm_f16", - "rms_norm_channel_f32", "rms_norm_channel_f16", - "f8e4m3_to_f32", "f8e4m3_to_f16", - "adaln_modulate_f32", "adaln_modulate_f16", - "softmax_last_dim_f32", "softmax_last_dim_f16", - "layer_norm_f32", "layer_norm_f16", - "rope_f32", "rope_f16", + "gelu_f32", + "gelu_f16", + "sigmoid_f32", + "sigmoid_f16", + "silu_f32", + "silu_f16", + "stable_softplus_f32", + "stable_softplus_f16", + "silu_mul_f32", + "silu_mul_f16", + "add3_f32", + "add3_f16", + "exp_mul_f32", + "exp_mul_f16", + "sub_mul_f32", + "sub_mul_f16", + "add_scaled_f32", + "add_scaled_f16", + "depthwise_conv1d_silu_f32", + "depthwise_conv1d_silu_f16", + "depthwise_conv1d_bias_f32", + "depthwise_conv1d_bias_f16", + "rms_norm_f32", + "rms_norm_f16", + "rms_norm_gated_f32", + "rms_norm_gated_f16", + "add_rms_norm_f32", + "add_rms_norm_f16", + "rms_norm_channel_f32", + "rms_norm_channel_f16", + "f8e4m3_to_f32", + "f8e4m3_to_f16", + "adaln_modulate_f32", + "adaln_modulate_f16", + "softmax_last_dim_f32", + "softmax_last_dim_f16", + "layer_norm_f32", + "layer_norm_f16", + "rope_f32", + "rope_f16", "fused_vector_attention_f16", "fused_vector_attention_f32", + "q4_matvec_f16", + "q4_matmul_tiled_f16", ]; struct PipelineCache { @@ -69,17 +93,26 @@ impl PipelineCache { return Ok(pipeline.clone()); } } - let _guard = self.compile_lock.lock().map_err(|e| candle_core::Error::Msg(format!("compile lock: {e}")))?; + let 
_guard = self
+            .compile_lock
+            .lock()
+            .map_err(|e| candle_core::Error::Msg(format!("compile lock: {e}")))?;
         if let Ok(cache) = self.pipelines.read() {
             if let Some(pipeline) = cache.get(kernel_name) {
                 return Ok(pipeline.clone());
             }
         }
-        let lib = device.new_library_with_source(FUSED_OPS_MSL, None)
+        let lib = device
+            .new_library_with_source(FUSED_OPS_MSL, None)
             .map_err(|e| candle_core::Error::Msg(format!("metal shader compile: {e}")))?;
-        let mut cache = self.pipelines.write().map_err(|e| candle_core::Error::Msg(format!("pipeline write lock: {e}")))?;
+        let mut cache = self
+            .pipelines
+            .write()
+            .map_err(|e| candle_core::Error::Msg(format!("pipeline write lock: {e}")))?;
         for &name in ALL_KERNELS {
-            if cache.contains_key(name) { continue; }
+            if cache.contains_key(name) {
+                continue;
+            }
             if let Ok(func) = lib.get_function(name, None) {
                 if let Ok(pipeline) = device.new_compute_pipeline_state_with_function(&func) {
                     cache.insert(name, pipeline);
@@ -92,15 +125,20 @@ impl PipelineCache {
     }
 }
 
-static PIPELINE_CACHE: std::sync::LazyLock<PipelineCache> = std::sync::LazyLock::new(PipelineCache::new);
+static PIPELINE_CACHE: std::sync::LazyLock<PipelineCache> =
+    std::sync::LazyLock::new(PipelineCache::new);
 
 // ─── Helper: dispatch an elementwise 2-input kernel ─────────────────
 
 #[inline]
 fn dispatch_binary(
-    s1: &candle_core::MetalStorage, l1: &Layout,
-    s2: &candle_core::MetalStorage, l2: &Layout,
-    f32_kernel: &'static str, f16_kernel: &'static str, label: &'static str,
+    s1: &candle_core::MetalStorage,
+    l1: &Layout,
+    s2: &candle_core::MetalStorage,
+    l2: &Layout,
+    f32_kernel: &'static str,
+    f16_kernel: &'static str,
+    label: &'static str,
 ) -> Result<(candle_core::MetalStorage, Shape)> {
     let device = s1.device();
     let el = l1.shape().elem_count();
@@ -119,22 +157,36 @@ fn dispatch_binary(
     candle_metal_kernels::utils::set_param(&encoder, 1, (s2.buffer(), off2));
     candle_metal_kernels::utils::set_param(&encoder, 2, (&*output, 0usize));
     candle_metal_kernels::utils::set_param(&encoder, 3, el as u32);
-    let grid = objc2_metal::MTLSize { width: el, height: 1, depth: 1 };
+    let grid = objc2_metal::MTLSize {
+        width: el,
+        height: 1,
+        depth: 1,
+    };
     let group = candle_metal_kernels::utils::get_block_dims(el, 1, 1);
     encoder.dispatch_threads(grid, group);
-    Ok((candle_core::MetalStorage::new(output, device.clone(), el, s1.dtype()), l1.shape().clone()))
+    Ok((
+        candle_core::MetalStorage::new(output, device.clone(), el, s1.dtype()),
+        l1.shape().clone(),
+    ))
 }
 
 // ─── Helper: dispatch an elementwise 3-input kernel ─────────────────
 
-struct TernaryKernel { f32_kernel: &'static str, f16_kernel: &'static str, label: &'static str }
+struct TernaryKernel {
+    f32_kernel: &'static str,
+    f16_kernel: &'static str,
+    label: &'static str,
+}
 
 #[allow(clippy::too_many_arguments)]
 #[inline]
 fn dispatch_ternary(
-    s1: &candle_core::MetalStorage, l1: &Layout,
-    s2: &candle_core::MetalStorage, l2: &Layout,
-    s3: &candle_core::MetalStorage, l3: &Layout,
+    s1: &candle_core::MetalStorage,
+    l1: &Layout,
+    s2: &candle_core::MetalStorage,
+    l2: &Layout,
+    s3: &candle_core::MetalStorage,
+    l3: &Layout,
     k: &TernaryKernel,
 ) -> Result<(candle_core::MetalStorage, Shape)> {
     let device = s1.device();
@@ -156,21 +208,39 @@ fn dispatch_ternary(
     candle_metal_kernels::utils::set_param(&encoder, 2, (s3.buffer(), off3));
     candle_metal_kernels::utils::set_param(&encoder, 3, (&*output, 0usize));
     candle_metal_kernels::utils::set_param(&encoder, 4, el as u32);
-    let grid = objc2_metal::MTLSize { width: el, height: 1, depth: 1 };
+    let grid = 
objc2_metal::MTLSize { + width: el, + height: 1, + depth: 1, + }; let group = candle_metal_kernels::utils::get_block_dims(el, 1, 1); encoder.dispatch_threads(grid, group); - Ok((candle_core::MetalStorage::new(output, device.clone(), el, s1.dtype()), l1.shape().clone())) + Ok(( + candle_core::MetalStorage::new(output, device.clone(), el, s1.dtype()), + l1.shape().clone(), + )) } -const ADD3_KERNEL: TernaryKernel = TernaryKernel { f32_kernel: "add3_f32", f16_kernel: "add3_f16", label: "add3" }; -const SUB_MUL_KERNEL: TernaryKernel = TernaryKernel { f32_kernel: "sub_mul_f32", f16_kernel: "sub_mul_f16", label: "sub_mul" }; +const ADD3_KERNEL: TernaryKernel = TernaryKernel { + f32_kernel: "add3_f32", + f16_kernel: "add3_f16", + label: "add3", +}; +const SUB_MUL_KERNEL: TernaryKernel = TernaryKernel { + f32_kernel: "sub_mul_f32", + f16_kernel: "sub_mul_f16", + label: "sub_mul", +}; // ─── Helper: dispatch a unary elementwise kernel ───────────────────── #[inline] fn dispatch_unary( - s: &candle_core::MetalStorage, l: &Layout, - f32_kernel: &'static str, f16_kernel: &'static str, label: &'static str, + s: &candle_core::MetalStorage, + l: &Layout, + f32_kernel: &'static str, + f16_kernel: &'static str, + label: &'static str, ) -> Result<(candle_core::MetalStorage, Shape)> { let device = s.device(); let el = l.shape().elem_count(); @@ -187,46 +257,97 @@ fn dispatch_unary( candle_metal_kernels::utils::set_param(&encoder, 0, (s.buffer(), offset)); candle_metal_kernels::utils::set_param(&encoder, 1, (&*output, 0usize)); candle_metal_kernels::utils::set_param(&encoder, 2, el as u32); - let grid = objc2_metal::MTLSize { width: el, height: 1, depth: 1 }; + let grid = objc2_metal::MTLSize { + width: el, + height: 1, + depth: 1, + }; let group = candle_metal_kernels::utils::get_block_dims(el, 1, 1); encoder.dispatch_threads(grid, group); - Ok((candle_core::MetalStorage::new(output, device.clone(), el, s.dtype()), l.shape().clone())) + Ok(( + candle_core::MetalStorage::new(output, device.clone(), el, s.dtype()), + l.shape().clone(), + )) } // ─── CustomOp structs ─────────────────────────────────────────────── struct MetalGelu; impl candle_core::CustomOp1 for MetalGelu { - fn name(&self) -> &'static str { "metal_gelu" } - fn cpu_fwd(&self, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalGelu: expected Metal device") } - fn metal_fwd(&self, s: &candle_core::MetalStorage, l: &Layout) -> Result<(candle_core::MetalStorage, Shape)> { dispatch_unary(s, l, "gelu_f32", "gelu_f16", "gelu") } + fn name(&self) -> &'static str { + "metal_gelu" + } + fn cpu_fwd(&self, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalGelu: expected Metal device") + } + fn metal_fwd( + &self, + s: &candle_core::MetalStorage, + l: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { + dispatch_unary(s, l, "gelu_f32", "gelu_f16", "gelu") + } } struct MetalSigmoid; impl candle_core::CustomOp1 for MetalSigmoid { - fn name(&self) -> &'static str { "metal_sigmoid" } - fn cpu_fwd(&self, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalSigmoid: expected Metal device") } - fn metal_fwd(&self, s: &candle_core::MetalStorage, l: &Layout) -> Result<(candle_core::MetalStorage, Shape)> { dispatch_unary(s, l, "sigmoid_f32", "sigmoid_f16", "sigmoid") } + fn name(&self) -> &'static str { + "metal_sigmoid" + } + fn cpu_fwd(&self, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalSigmoid: expected 
Metal device") + } + fn metal_fwd( + &self, + s: &candle_core::MetalStorage, + l: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { + dispatch_unary(s, l, "sigmoid_f32", "sigmoid_f16", "sigmoid") + } } struct MetalSilu; impl candle_core::CustomOp1 for MetalSilu { - fn name(&self) -> &'static str { "metal_silu" } - fn cpu_fwd(&self, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalSilu: expected Metal device") } - fn metal_fwd(&self, s: &candle_core::MetalStorage, l: &Layout) -> Result<(candle_core::MetalStorage, Shape)> { dispatch_unary(s, l, "silu_f32", "silu_f16", "silu") } + fn name(&self) -> &'static str { + "metal_silu" + } + fn cpu_fwd(&self, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalSilu: expected Metal device") + } + fn metal_fwd( + &self, + s: &candle_core::MetalStorage, + l: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { + dispatch_unary(s, l, "silu_f32", "silu_f16", "silu") + } } struct MetalSoftmaxLastDim; impl candle_core::CustomOp1 for MetalSoftmaxLastDim { - fn name(&self) -> &'static str { "metal_softmax_last_dim" } - fn cpu_fwd(&self, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalSoftmaxLastDim: expected Metal device") } - fn metal_fwd(&self, s: &candle_core::MetalStorage, l: &Layout) -> Result<(candle_core::MetalStorage, Shape)> { + fn name(&self) -> &'static str { + "metal_softmax_last_dim" + } + fn cpu_fwd(&self, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalSoftmaxLastDim: expected Metal device") + } + fn metal_fwd( + &self, + s: &candle_core::MetalStorage, + l: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { let device = s.device(); let dims = l.shape().dims(); let el = l.shape().elem_count(); - let last_dim = *dims.last().ok_or_else(|| candle_core::Error::Msg("empty shape".into()))?; + let last_dim = *dims + .last() + .ok_or_else(|| candle_core::Error::Msg("empty shape".into()))?; let num_rows = el / last_dim; - let kernel_name: &'static str = match s.dtype() { DType::F32 => "softmax_last_dim_f32", DType::F16 => "softmax_last_dim_f16", dt => candle_core::bail!("softmax metal: unsupported dtype {dt:?}") }; + let kernel_name: &'static str = match s.dtype() { + DType::F32 => "softmax_last_dim_f32", + DType::F16 => "softmax_last_dim_f16", + dt => candle_core::bail!("softmax metal: unsupported dtype {dt:?}"), + }; let pipeline = PIPELINE_CACHE.get_or_create(device, kernel_name)?; let output = device.new_buffer(el, s.dtype(), "softmax")?; let encoder = device.command_encoder()?; @@ -237,24 +358,59 @@ impl candle_core::CustomOp1 for MetalSoftmaxLastDim { candle_metal_kernels::utils::set_param(&encoder, 2, last_dim as u32); let max_threads = pipeline.max_total_threads_per_threadgroup(); let tg_width = last_dim.min(max_threads); - let grid = objc2_metal::MTLSize { width: last_dim, height: num_rows, depth: 1 }; - let group = objc2_metal::MTLSize { width: tg_width, height: 1, depth: 1 }; + let grid = objc2_metal::MTLSize { + width: last_dim, + height: num_rows, + depth: 1, + }; + let group = objc2_metal::MTLSize { + width: tg_width, + height: 1, + depth: 1, + }; encoder.dispatch_threads(grid, group); - Ok((candle_core::MetalStorage::new(output, device.clone(), el, s.dtype()), l.shape().clone())) + Ok(( + candle_core::MetalStorage::new(output, device.clone(), el, s.dtype()), + l.shape().clone(), + )) } } struct MetalRope; impl candle_core::CustomOp3 for MetalRope { - fn 
name(&self) -> &'static str { "metal_rope" } - fn cpu_fwd(&self, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalRope: expected Metal device") } + fn name(&self) -> &'static str { + "metal_rope" + } + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalRope: expected Metal device") + } #[allow(clippy::too_many_arguments)] - fn metal_fwd(&self, s_x: &candle_core::MetalStorage, l_x: &Layout, s_cos: &candle_core::MetalStorage, l_cos: &Layout, s_sin: &candle_core::MetalStorage, l_sin: &Layout) -> Result<(candle_core::MetalStorage, Shape)> { + fn metal_fwd( + &self, + s_x: &candle_core::MetalStorage, + l_x: &Layout, + s_cos: &candle_core::MetalStorage, + l_cos: &Layout, + s_sin: &candle_core::MetalStorage, + l_sin: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { let device = s_x.device(); let dims = l_x.shape().dims(); let el = l_x.shape().elem_count(); let (head_dim, seq_len) = (dims[dims.len() - 1], dims[dims.len() - 2]); - let kernel_name: &'static str = match s_x.dtype() { DType::F32 => "rope_f32", DType::F16 => "rope_f16", dt => candle_core::bail!("rope metal: unsupported dtype {dt:?}") }; + let kernel_name: &'static str = match s_x.dtype() { + DType::F32 => "rope_f32", + DType::F16 => "rope_f16", + dt => candle_core::bail!("rope metal: unsupported dtype {dt:?}"), + }; let pipeline = PIPELINE_CACHE.get_or_create(device, kernel_name)?; let output = device.new_buffer(el, s_x.dtype(), "rope")?; let encoder = device.command_encoder()?; @@ -268,10 +424,17 @@ impl candle_core::CustomOp3 for MetalRope { candle_metal_kernels::utils::set_param(&encoder, 3, (&*output, 0usize)); candle_metal_kernels::utils::set_param(&encoder, 4, head_dim as u32); candle_metal_kernels::utils::set_param(&encoder, 5, seq_len as u32); - let grid = objc2_metal::MTLSize { width: el, height: 1, depth: 1 }; + let grid = objc2_metal::MTLSize { + width: el, + height: 1, + depth: 1, + }; let group = candle_metal_kernels::utils::get_block_dims(el, 1, 1); encoder.dispatch_threads(grid, group); - Ok((candle_core::MetalStorage::new(output, device.clone(), el, s_x.dtype()), l_x.shape().clone())) + Ok(( + candle_core::MetalStorage::new(output, device.clone(), el, s_x.dtype()), + l_x.shape().clone(), + )) } } @@ -281,15 +444,37 @@ struct MetalLayerNorm { bias_layout: Layout, } impl candle_core::CustomOp2 for MetalLayerNorm { - fn name(&self) -> &'static str { "metal_layer_norm" } - fn cpu_fwd(&self, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalLayerNorm: expected Metal device") } - fn metal_fwd(&self, s_x: &candle_core::MetalStorage, l_x: &Layout, s_w: &candle_core::MetalStorage, _l_w: &Layout) -> Result<(candle_core::MetalStorage, Shape)> { + fn name(&self) -> &'static str { + "metal_layer_norm" + } + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalLayerNorm: expected Metal device") + } + fn metal_fwd( + &self, + s_x: &candle_core::MetalStorage, + l_x: &Layout, + s_w: &candle_core::MetalStorage, + _l_w: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { let device = s_x.device(); let dims = l_x.shape().dims(); let el = l_x.shape().elem_count(); - let hidden = *dims.last().ok_or_else(|| candle_core::Error::Msg("empty 
shape".into()))?; + let hidden = *dims + .last() + .ok_or_else(|| candle_core::Error::Msg("empty shape".into()))?; let num_rows = el / hidden; - let kernel_name: &'static str = match s_x.dtype() { DType::F32 => "layer_norm_f32", DType::F16 => "layer_norm_f16", dt => candle_core::bail!("layer_norm metal: unsupported dtype {dt:?}") }; + let kernel_name: &'static str = match s_x.dtype() { + DType::F32 => "layer_norm_f32", + DType::F16 => "layer_norm_f16", + dt => candle_core::bail!("layer_norm metal: unsupported dtype {dt:?}"), + }; let pipeline = PIPELINE_CACHE.get_or_create(device, kernel_name)?; let output = device.new_buffer(el, s_x.dtype(), "layer_norm")?; let encoder = device.command_encoder()?; @@ -304,68 +489,194 @@ impl candle_core::CustomOp2 for MetalLayerNorm { candle_metal_kernels::utils::set_param(&encoder, 5, self.eps); let max_threads = pipeline.max_total_threads_per_threadgroup(); let tg_width = hidden.min(max_threads); - let grid = objc2_metal::MTLSize { width: hidden, height: num_rows, depth: 1 }; - let group = objc2_metal::MTLSize { width: tg_width, height: 1, depth: 1 }; + let grid = objc2_metal::MTLSize { + width: hidden, + height: num_rows, + depth: 1, + }; + let group = objc2_metal::MTLSize { + width: tg_width, + height: 1, + depth: 1, + }; encoder.dispatch_threads(grid, group); - Ok((candle_core::MetalStorage::new(output, device.clone(), el, s_x.dtype()), l_x.shape().clone())) + Ok(( + candle_core::MetalStorage::new(output, device.clone(), el, s_x.dtype()), + l_x.shape().clone(), + )) } } struct MetalSiluMul; impl candle_core::CustomOp2 for MetalSiluMul { - fn name(&self) -> &'static str { "metal_silu_mul" } - fn cpu_fwd(&self, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalSiluMul: expected Metal device") } - fn metal_fwd(&self, s1: &candle_core::MetalStorage, l1: &Layout, s2: &candle_core::MetalStorage, l2: &Layout) -> Result<(candle_core::MetalStorage, Shape)> { + fn name(&self) -> &'static str { + "metal_silu_mul" + } + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalSiluMul: expected Metal device") + } + fn metal_fwd( + &self, + s1: &candle_core::MetalStorage, + l1: &Layout, + s2: &candle_core::MetalStorage, + l2: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { dispatch_binary(s1, l1, s2, l2, "silu_mul_f32", "silu_mul_f16", "silu_mul") } } struct MetalStableSoftplus; impl candle_core::CustomOp1 for MetalStableSoftplus { - fn name(&self) -> &'static str { "metal_stable_softplus" } - fn cpu_fwd(&self, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalStableSoftplus: expected Metal device") } - fn metal_fwd(&self, s: &candle_core::MetalStorage, l: &Layout) -> Result<(candle_core::MetalStorage, Shape)> { dispatch_unary(s, l, "stable_softplus_f32", "stable_softplus_f16", "stable_softplus") + fn name(&self) -> &'static str { + "metal_stable_softplus" + } + fn cpu_fwd(&self, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalStableSoftplus: expected Metal device") + } + fn metal_fwd( + &self, + s: &candle_core::MetalStorage, + l: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { + dispatch_unary( + s, + l, + "stable_softplus_f32", + "stable_softplus_f16", + "stable_softplus", + ) } } struct MetalExpMul; impl candle_core::CustomOp2 for MetalExpMul { - fn name(&self) -> &'static str { "metal_exp_mul" } - fn 
cpu_fwd(&self, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalExpMul: expected Metal device") } - fn metal_fwd(&self, s1: &candle_core::MetalStorage, l1: &Layout, s2: &candle_core::MetalStorage, l2: &Layout) -> Result<(candle_core::MetalStorage, Shape)> { + fn name(&self) -> &'static str { + "metal_exp_mul" + } + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalExpMul: expected Metal device") + } + fn metal_fwd( + &self, + s1: &candle_core::MetalStorage, + l1: &Layout, + s2: &candle_core::MetalStorage, + l2: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { dispatch_binary(s1, l1, s2, l2, "exp_mul_f32", "exp_mul_f16", "exp_mul") } } struct MetalAdd3; impl candle_core::CustomOp3 for MetalAdd3 { - fn name(&self) -> &'static str { "metal_add3" } - fn cpu_fwd(&self, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalAdd3: expected Metal device") } - fn metal_fwd(&self, s1: &candle_core::MetalStorage, l1: &Layout, s2: &candle_core::MetalStorage, l2: &Layout, s3: &candle_core::MetalStorage, l3: &Layout) -> Result<(candle_core::MetalStorage, Shape)> { + fn name(&self) -> &'static str { + "metal_add3" + } + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalAdd3: expected Metal device") + } + fn metal_fwd( + &self, + s1: &candle_core::MetalStorage, + l1: &Layout, + s2: &candle_core::MetalStorage, + l2: &Layout, + s3: &candle_core::MetalStorage, + l3: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { dispatch_ternary(s1, l1, s2, l2, s3, l3, &ADD3_KERNEL) } } struct MetalSubMul; impl candle_core::CustomOp3 for MetalSubMul { - fn name(&self) -> &'static str { "metal_sub_mul" } - fn cpu_fwd(&self, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalSubMul: expected Metal device") } - fn metal_fwd(&self, s1: &candle_core::MetalStorage, l1: &Layout, s2: &candle_core::MetalStorage, l2: &Layout, s3: &candle_core::MetalStorage, l3: &Layout) -> Result<(candle_core::MetalStorage, Shape)> { + fn name(&self) -> &'static str { + "metal_sub_mul" + } + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalSubMul: expected Metal device") + } + fn metal_fwd( + &self, + s1: &candle_core::MetalStorage, + l1: &Layout, + s2: &candle_core::MetalStorage, + l2: &Layout, + s3: &candle_core::MetalStorage, + l3: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { dispatch_ternary(s1, l1, s2, l2, s3, l3, &SUB_MUL_KERNEL) } } struct MetalAddScaled; impl candle_core::CustomOp3 for MetalAddScaled { - fn name(&self) -> &'static str { "metal_add_scaled" } - fn cpu_fwd(&self, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalAddScaled: expected Metal device") } + fn name(&self) -> &'static str { + "metal_add_scaled" + } + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalAddScaled: 
expected Metal device") + } #[allow(clippy::too_many_arguments)] - fn metal_fwd(&self, s1: &candle_core::MetalStorage, l1: &Layout, s2: &candle_core::MetalStorage, l2: &Layout, s3: &candle_core::MetalStorage, l3: &Layout) -> Result<(candle_core::MetalStorage, Shape)> { + fn metal_fwd( + &self, + s1: &candle_core::MetalStorage, + l1: &Layout, + s2: &candle_core::MetalStorage, + l2: &Layout, + s3: &candle_core::MetalStorage, + l3: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { let device = s1.device(); let el = l1.shape().elem_count(); let dims = l1.shape().dims(); - let (channels, time_len) = if dims.len() >= 3 { (dims[dims.len() - 2], dims[dims.len() - 1]) } else { (1usize, el) }; - let kernel_name: &'static str = match s1.dtype() { DType::F32 => "add_scaled_f32", DType::F16 => "add_scaled_f16", dt => candle_core::bail!("add_scaled metal: unsupported dtype {dt:?}") }; + let (channels, time_len) = if dims.len() >= 3 { + (dims[dims.len() - 2], dims[dims.len() - 1]) + } else { + (1usize, el) + }; + let kernel_name: &'static str = match s1.dtype() { + DType::F32 => "add_scaled_f32", + DType::F16 => "add_scaled_f16", + dt => candle_core::bail!("add_scaled metal: unsupported dtype {dt:?}"), + }; let pipeline = PIPELINE_CACHE.get_or_create(device, kernel_name)?; let output = device.new_buffer(el, s1.dtype(), "add_scaled")?; let encoder = device.command_encoder()?; @@ -380,23 +691,50 @@ impl candle_core::CustomOp3 for MetalAddScaled { candle_metal_kernels::utils::set_param(&encoder, 4, el as u32); candle_metal_kernels::utils::set_param(&encoder, 5, channels as u32); candle_metal_kernels::utils::set_param(&encoder, 6, time_len as u32); - let grid = objc2_metal::MTLSize { width: el, height: 1, depth: 1 }; + let grid = objc2_metal::MTLSize { + width: el, + height: 1, + depth: 1, + }; let group = candle_metal_kernels::utils::get_block_dims(el, 1, 1); encoder.dispatch_threads(grid, group); - Ok((candle_core::MetalStorage::new(output, device.clone(), el, s1.dtype()), l1.shape().clone())) + Ok(( + candle_core::MetalStorage::new(output, device.clone(), el, s1.dtype()), + l1.shape().clone(), + )) } } struct MetalDepthwiseConv1dSilu; impl candle_core::CustomOp2 for MetalDepthwiseConv1dSilu { - fn name(&self) -> &'static str { "metal_depthwise_conv1d_silu" } - fn cpu_fwd(&self, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalDepthwiseConv1dSilu: expected Metal device") } - fn metal_fwd(&self, s_win: &candle_core::MetalStorage, l_win: &Layout, s_wt: &candle_core::MetalStorage, l_wt: &Layout) -> Result<(candle_core::MetalStorage, Shape)> { + fn name(&self) -> &'static str { + "metal_depthwise_conv1d_silu" + } + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalDepthwiseConv1dSilu: expected Metal device") + } + fn metal_fwd( + &self, + s_win: &candle_core::MetalStorage, + l_win: &Layout, + s_wt: &candle_core::MetalStorage, + l_wt: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { let device = s_win.device(); let win_dims = l_win.shape().dims(); let (batch, channels, kernel_size) = (win_dims[0], win_dims[1], win_dims[2]); let out_count = batch * channels; - let kernel_name: &'static str = match s_win.dtype() { DType::F32 => "depthwise_conv1d_silu_f32", DType::F16 => "depthwise_conv1d_silu_f16", dt => candle_core::bail!("depthwise_conv1d_silu metal: unsupported dtype {dt:?}") }; + let kernel_name: &'static str = match 
s_win.dtype() { + DType::F32 => "depthwise_conv1d_silu_f32", + DType::F16 => "depthwise_conv1d_silu_f16", + dt => candle_core::bail!("depthwise_conv1d_silu metal: unsupported dtype {dt:?}"), + }; let pipeline = PIPELINE_CACHE.get_or_create(device, kernel_name)?; let output = device.new_buffer(out_count, s_win.dtype(), "dw_conv1d_silu")?; let encoder = device.command_encoder()?; @@ -409,19 +747,46 @@ impl candle_core::CustomOp2 for MetalDepthwiseConv1dSilu { candle_metal_kernels::utils::set_param(&encoder, 3, out_count as u32); candle_metal_kernels::utils::set_param(&encoder, 4, channels as u32); candle_metal_kernels::utils::set_param(&encoder, 5, kernel_size as u32); - let grid = objc2_metal::MTLSize { width: out_count, height: 1, depth: 1 }; + let grid = objc2_metal::MTLSize { + width: out_count, + height: 1, + depth: 1, + }; let group = candle_metal_kernels::utils::get_block_dims(out_count, 1, 1); encoder.dispatch_threads(grid, group); - Ok((candle_core::MetalStorage::new(output, device.clone(), out_count, s_win.dtype()), Shape::from(vec![batch, channels]))) + Ok(( + candle_core::MetalStorage::new(output, device.clone(), out_count, s_win.dtype()), + Shape::from(vec![batch, channels]), + )) } } struct MetalDepthwiseConv1dBias; impl candle_core::CustomOp3 for MetalDepthwiseConv1dBias { - fn name(&self) -> &'static str { "metal_depthwise_conv1d_bias" } - fn cpu_fwd(&self, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalDepthwiseConv1dBias: expected Metal device") } + fn name(&self) -> &'static str { + "metal_depthwise_conv1d_bias" + } + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalDepthwiseConv1dBias: expected Metal device") + } #[allow(clippy::too_many_arguments)] - fn metal_fwd(&self, s_in: &candle_core::MetalStorage, l_in: &Layout, s_wt: &candle_core::MetalStorage, l_wt: &Layout, s_bias: &candle_core::MetalStorage, l_bias: &Layout) -> Result<(candle_core::MetalStorage, Shape)> { + fn metal_fwd( + &self, + s_in: &candle_core::MetalStorage, + l_in: &Layout, + s_wt: &candle_core::MetalStorage, + l_wt: &Layout, + s_bias: &candle_core::MetalStorage, + l_bias: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { let device = s_in.device(); let in_dims = l_in.shape().dims(); let (batch, channels, t_padded) = (in_dims[0], in_dims[1], in_dims[2]); @@ -429,7 +794,11 @@ impl candle_core::CustomOp3 for MetalDepthwiseConv1dBias { let kernel_size = wt_dims[wt_dims.len() - 1]; let out_len = t_padded - kernel_size + 1; let out_count = batch * channels * out_len; - let kernel_name: &'static str = match s_in.dtype() { DType::F32 => "depthwise_conv1d_bias_f32", DType::F16 => "depthwise_conv1d_bias_f16", dt => candle_core::bail!("depthwise_conv1d_bias metal: unsupported dtype {dt:?}") }; + let kernel_name: &'static str = match s_in.dtype() { + DType::F32 => "depthwise_conv1d_bias_f32", + DType::F16 => "depthwise_conv1d_bias_f16", + dt => candle_core::bail!("depthwise_conv1d_bias metal: unsupported dtype {dt:?}"), + }; let pipeline = PIPELINE_CACHE.get_or_create(device, kernel_name)?; let output = device.new_buffer(out_count, s_in.dtype(), "dw_conv1d_bias")?; let encoder = device.command_encoder()?; @@ -446,24 +815,55 @@ impl candle_core::CustomOp3 for MetalDepthwiseConv1dBias { candle_metal_kernels::utils::set_param(&encoder, 6, out_len as u32); 
candle_metal_kernels::utils::set_param(&encoder, 7, t_padded as u32); candle_metal_kernels::utils::set_param(&encoder, 8, kernel_size as u32); - let grid = objc2_metal::MTLSize { width: out_count, height: 1, depth: 1 }; + let grid = objc2_metal::MTLSize { + width: out_count, + height: 1, + depth: 1, + }; let group = candle_metal_kernels::utils::get_block_dims(out_count, 1, 1); encoder.dispatch_threads(grid, group); - Ok((candle_core::MetalStorage::new(output, device.clone(), out_count, s_in.dtype()), Shape::from(vec![batch, channels, out_len]))) + Ok(( + candle_core::MetalStorage::new(output, device.clone(), out_count, s_in.dtype()), + Shape::from(vec![batch, channels, out_len]), + )) } } -struct MetalRmsNorm { eps: f32 } +struct MetalRmsNorm { + eps: f32, +} impl candle_core::CustomOp2 for MetalRmsNorm { - fn name(&self) -> &'static str { "metal_rms_norm" } - fn cpu_fwd(&self, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalRmsNorm: expected Metal device") } - fn metal_fwd(&self, s_x: &candle_core::MetalStorage, l_x: &Layout, s_w: &candle_core::MetalStorage, _l_w: &Layout) -> Result<(candle_core::MetalStorage, Shape)> { + fn name(&self) -> &'static str { + "metal_rms_norm" + } + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalRmsNorm: expected Metal device") + } + fn metal_fwd( + &self, + s_x: &candle_core::MetalStorage, + l_x: &Layout, + s_w: &candle_core::MetalStorage, + _l_w: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { let device = s_x.device(); let dims = l_x.shape().dims(); let el = l_x.shape().elem_count(); - let hidden = *dims.last().ok_or_else(|| candle_core::Error::Msg("empty shape".into()))?; + let hidden = *dims + .last() + .ok_or_else(|| candle_core::Error::Msg("empty shape".into()))?; let num_rows = el / hidden; - let kernel_name: &'static str = match s_x.dtype() { DType::F32 => "rms_norm_f32", DType::F16 => "rms_norm_f16", dt => candle_core::bail!("rms_norm metal: unsupported dtype {dt:?}") }; + let kernel_name: &'static str = match s_x.dtype() { + DType::F32 => "rms_norm_f32", + DType::F16 => "rms_norm_f16", + dt => candle_core::bail!("rms_norm metal: unsupported dtype {dt:?}"), + }; let pipeline = PIPELINE_CACHE.get_or_create(device, kernel_name)?; let output = device.new_buffer(el, s_x.dtype(), "rms_norm")?; let encoder = device.command_encoder()?; @@ -476,25 +876,64 @@ impl candle_core::CustomOp2 for MetalRmsNorm { candle_metal_kernels::utils::set_param(&encoder, 4, self.eps); let max_threads = pipeline.max_total_threads_per_threadgroup(); let tg_width = hidden.min(max_threads); - let grid = objc2_metal::MTLSize { width: hidden, height: num_rows, depth: 1 }; - let group = objc2_metal::MTLSize { width: tg_width, height: 1, depth: 1 }; + let grid = objc2_metal::MTLSize { + width: hidden, + height: num_rows, + depth: 1, + }; + let group = objc2_metal::MTLSize { + width: tg_width, + height: 1, + depth: 1, + }; encoder.dispatch_threads(grid, group); - Ok((candle_core::MetalStorage::new(output, device.clone(), el, s_x.dtype()), l_x.shape().clone())) + Ok(( + candle_core::MetalStorage::new(output, device.clone(), el, s_x.dtype()), + l_x.shape().clone(), + )) } } -struct MetalRmsNormGated { eps: f32 } +struct MetalRmsNormGated { + eps: f32, +} impl candle_core::CustomOp3 for MetalRmsNormGated { - fn name(&self) -> &'static str { "metal_rms_norm_gated" } - fn cpu_fwd(&self, _: &CpuStorage, _: 
&Layout, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalRmsNormGated: expected Metal device") } + fn name(&self) -> &'static str { + "metal_rms_norm_gated" + } + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalRmsNormGated: expected Metal device") + } #[allow(clippy::too_many_arguments)] - fn metal_fwd(&self, s_x: &candle_core::MetalStorage, l_x: &Layout, s_z: &candle_core::MetalStorage, l_z: &Layout, s_w: &candle_core::MetalStorage, l_w: &Layout) -> Result<(candle_core::MetalStorage, Shape)> { + fn metal_fwd( + &self, + s_x: &candle_core::MetalStorage, + l_x: &Layout, + s_z: &candle_core::MetalStorage, + l_z: &Layout, + s_w: &candle_core::MetalStorage, + l_w: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { let device = s_x.device(); let dims = l_x.shape().dims(); let el = l_x.shape().elem_count(); - let hidden = *dims.last().ok_or_else(|| candle_core::Error::Msg("empty shape".into()))?; + let hidden = *dims + .last() + .ok_or_else(|| candle_core::Error::Msg("empty shape".into()))?; let num_rows = el / hidden; - let kernel_name: &'static str = match s_x.dtype() { DType::F32 => "rms_norm_gated_f32", DType::F16 => "rms_norm_gated_f16", dt => candle_core::bail!("rms_norm_gated metal: unsupported dtype {dt:?}") }; + let kernel_name: &'static str = match s_x.dtype() { + DType::F32 => "rms_norm_gated_f32", + DType::F16 => "rms_norm_gated_f16", + dt => candle_core::bail!("rms_norm_gated metal: unsupported dtype {dt:?}"), + }; let pipeline = PIPELINE_CACHE.get_or_create(device, kernel_name)?; let output = device.new_buffer(el, s_x.dtype(), "rms_norm_gated")?; let encoder = device.command_encoder()?; @@ -510,25 +949,64 @@ impl candle_core::CustomOp3 for MetalRmsNormGated { candle_metal_kernels::utils::set_param(&encoder, 5, self.eps); let max_threads = pipeline.max_total_threads_per_threadgroup(); let tg_width = hidden.min(max_threads); - let grid = objc2_metal::MTLSize { width: hidden, height: num_rows, depth: 1 }; - let group = objc2_metal::MTLSize { width: tg_width, height: 1, depth: 1 }; + let grid = objc2_metal::MTLSize { + width: hidden, + height: num_rows, + depth: 1, + }; + let group = objc2_metal::MTLSize { + width: tg_width, + height: 1, + depth: 1, + }; encoder.dispatch_threads(grid, group); - Ok((candle_core::MetalStorage::new(output, device.clone(), el, s_x.dtype()), l_x.shape().clone())) + Ok(( + candle_core::MetalStorage::new(output, device.clone(), el, s_x.dtype()), + l_x.shape().clone(), + )) } } -struct MetalAddRmsNorm { eps: f32 } +struct MetalAddRmsNorm { + eps: f32, +} impl candle_core::CustomOp3 for MetalAddRmsNorm { - fn name(&self) -> &'static str { "metal_add_rms_norm" } - fn cpu_fwd(&self, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalAddRmsNorm: expected Metal device") } + fn name(&self) -> &'static str { + "metal_add_rms_norm" + } + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalAddRmsNorm: expected Metal device") + } #[allow(clippy::too_many_arguments)] - fn metal_fwd(&self, s_a: &candle_core::MetalStorage, l_a: &Layout, s_b: &candle_core::MetalStorage, l_b: &Layout, s_w: &candle_core::MetalStorage, l_w: &Layout) -> 
Result<(candle_core::MetalStorage, Shape)> { + fn metal_fwd( + &self, + s_a: &candle_core::MetalStorage, + l_a: &Layout, + s_b: &candle_core::MetalStorage, + l_b: &Layout, + s_w: &candle_core::MetalStorage, + l_w: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { let device = s_a.device(); let dims = l_a.shape().dims(); let el = l_a.shape().elem_count(); - let hidden = *dims.last().ok_or_else(|| candle_core::Error::Msg("empty shape".into()))?; + let hidden = *dims + .last() + .ok_or_else(|| candle_core::Error::Msg("empty shape".into()))?; let num_rows = el / hidden; - let kernel_name: &'static str = match s_a.dtype() { DType::F32 => "add_rms_norm_f32", DType::F16 => "add_rms_norm_f16", dt => candle_core::bail!("add_rms_norm metal: unsupported dtype {dt:?}") }; + let kernel_name: &'static str = match s_a.dtype() { + DType::F32 => "add_rms_norm_f32", + DType::F16 => "add_rms_norm_f16", + dt => candle_core::bail!("add_rms_norm metal: unsupported dtype {dt:?}"), + }; let pipeline = PIPELINE_CACHE.get_or_create(device, kernel_name)?; let output = device.new_buffer(2 * el, s_a.dtype(), "add_rms_norm")?; let encoder = device.command_encoder()?; @@ -545,24 +1023,57 @@ impl candle_core::CustomOp3 for MetalAddRmsNorm { candle_metal_kernels::utils::set_param(&encoder, 6, self.eps); let max_threads = pipeline.max_total_threads_per_threadgroup(); let tg_width = hidden.min(max_threads); - let grid = objc2_metal::MTLSize { width: hidden, height: num_rows, depth: 1 }; - let group = objc2_metal::MTLSize { width: tg_width, height: 1, depth: 1 }; + let grid = objc2_metal::MTLSize { + width: hidden, + height: num_rows, + depth: 1, + }; + let group = objc2_metal::MTLSize { + width: tg_width, + height: 1, + depth: 1, + }; encoder.dispatch_threads(grid, group); - Ok((candle_core::MetalStorage::new(output, device.clone(), 2 * el, s_a.dtype()), Shape::from(vec![2 * el]))) + Ok(( + candle_core::MetalStorage::new(output, device.clone(), 2 * el, s_a.dtype()), + Shape::from(vec![2 * el]), + )) } } -struct MetalRmsNormChannel { eps: f32 } +struct MetalRmsNormChannel { + eps: f32, +} impl candle_core::CustomOp2 for MetalRmsNormChannel { - fn name(&self) -> &'static str { "metal_rms_norm_channel" } - fn cpu_fwd(&self, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalRmsNormChannel: expected Metal device") } - fn metal_fwd(&self, s_x: &candle_core::MetalStorage, l_x: &Layout, s_w: &candle_core::MetalStorage, l_w: &Layout) -> Result<(candle_core::MetalStorage, Shape)> { + fn name(&self) -> &'static str { + "metal_rms_norm_channel" + } + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalRmsNormChannel: expected Metal device") + } + fn metal_fwd( + &self, + s_x: &candle_core::MetalStorage, + l_x: &Layout, + s_w: &candle_core::MetalStorage, + l_w: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { let device = s_x.device(); let dims = l_x.shape().dims(); let el = l_x.shape().elem_count(); let (batch, channels, time_len) = (dims[0], dims[1], dims[2]); let num_bt = batch * time_len; - let kernel_name: &'static str = match s_x.dtype() { DType::F32 => "rms_norm_channel_f32", DType::F16 => "rms_norm_channel_f16", dt => candle_core::bail!("rms_norm_channel metal: unsupported dtype {dt:?}") }; + let kernel_name: &'static str = match s_x.dtype() { + DType::F32 => "rms_norm_channel_f32", + DType::F16 => "rms_norm_channel_f16", + dt => 
candle_core::bail!("rms_norm_channel metal: unsupported dtype {dt:?}"), + }; let pipeline = PIPELINE_CACHE.get_or_create(device, kernel_name)?; let output = device.new_buffer(el, s_x.dtype(), "rms_norm_channel")?; let encoder = device.command_encoder()?; @@ -577,18 +1088,32 @@ impl candle_core::CustomOp2 for MetalRmsNormChannel { candle_metal_kernels::utils::set_param(&encoder, 5, self.eps); let max_threads = pipeline.max_total_threads_per_threadgroup(); let tg_width = channels.min(max_threads); - let grid = objc2_metal::MTLSize { width: channels, height: num_bt, depth: 1 }; - let group = objc2_metal::MTLSize { width: tg_width, height: 1, depth: 1 }; + let grid = objc2_metal::MTLSize { + width: channels, + height: num_bt, + depth: 1, + }; + let group = objc2_metal::MTLSize { + width: tg_width, + height: 1, + depth: 1, + }; encoder.dispatch_threads(grid, group); - Ok((candle_core::MetalStorage::new(output, device.clone(), el, s_x.dtype()), l_x.shape().clone())) + Ok(( + candle_core::MetalStorage::new(output, device.clone(), el, s_x.dtype()), + l_x.shape().clone(), + )) } } /// Like dispatch_unary but with a fixed kernel name and output dtype (for type-casting kernels). #[inline] fn dispatch_unary_cast( - s: &candle_core::MetalStorage, l: &Layout, - kernel_name: &'static str, out_dtype: DType, label: &'static str, + s: &candle_core::MetalStorage, + l: &Layout, + kernel_name: &'static str, + out_dtype: DType, + label: &'static str, ) -> Result<(candle_core::MetalStorage, Shape)> { let device = s.device(); let el = l.shape().elem_count(); @@ -600,24 +1125,51 @@ fn dispatch_unary_cast( candle_metal_kernels::utils::set_param(&encoder, 0, (s.buffer(), offset)); candle_metal_kernels::utils::set_param(&encoder, 1, (&*output, 0usize)); candle_metal_kernels::utils::set_param(&encoder, 2, el as u32); - let grid = objc2_metal::MTLSize { width: el, height: 1, depth: 1 }; + let grid = objc2_metal::MTLSize { + width: el, + height: 1, + depth: 1, + }; let group = candle_metal_kernels::utils::get_block_dims(el, 1, 1); encoder.dispatch_threads(grid, group); - Ok((candle_core::MetalStorage::new(output, device.clone(), el, out_dtype), l.shape().clone())) + Ok(( + candle_core::MetalStorage::new(output, device.clone(), el, out_dtype), + l.shape().clone(), + )) } struct MetalF8ToF32; impl candle_core::CustomOp1 for MetalF8ToF32 { - fn name(&self) -> &'static str { "metal_f8e4m3_to_f32" } - fn cpu_fwd(&self, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalF8ToF32: expected Metal device") } - fn metal_fwd(&self, s: &candle_core::MetalStorage, l: &Layout) -> Result<(candle_core::MetalStorage, Shape)> { dispatch_unary_cast(s, l, "f8e4m3_to_f32", DType::F32, "f8_to_f32") } + fn name(&self) -> &'static str { + "metal_f8e4m3_to_f32" + } + fn cpu_fwd(&self, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalF8ToF32: expected Metal device") + } + fn metal_fwd( + &self, + s: &candle_core::MetalStorage, + l: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { + dispatch_unary_cast(s, l, "f8e4m3_to_f32", DType::F32, "f8_to_f32") + } } struct MetalF8ToF16; impl candle_core::CustomOp1 for MetalF8ToF16 { - fn name(&self) -> &'static str { "metal_f8e4m3_to_f16" } - fn cpu_fwd(&self, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalF8ToF16: expected Metal device") } - fn metal_fwd(&self, s: &candle_core::MetalStorage, l: &Layout) -> Result<(candle_core::MetalStorage, Shape)> { dispatch_unary_cast(s, l, 
"f8e4m3_to_f16", DType::F16, "f8_to_f16") } + fn name(&self) -> &'static str { + "metal_f8e4m3_to_f16" + } + fn cpu_fwd(&self, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalF8ToF16: expected Metal device") + } + fn metal_fwd( + &self, + s: &candle_core::MetalStorage, + l: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { + dispatch_unary_cast(s, l, "f8e4m3_to_f16", DType::F16, "f8_to_f16") + } } struct MetalAdalnModulate { @@ -626,16 +1178,42 @@ struct MetalAdalnModulate { shift_layout: Layout, } impl candle_core::CustomOp3 for MetalAdalnModulate { - fn name(&self) -> &'static str { "metal_adaln_modulate" } - fn cpu_fwd(&self, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalAdalnModulate: expected Metal device") } + fn name(&self) -> &'static str { + "metal_adaln_modulate" + } + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalAdalnModulate: expected Metal device") + } #[allow(clippy::too_many_arguments)] - fn metal_fwd(&self, s_x: &candle_core::MetalStorage, l_x: &Layout, s_w: &candle_core::MetalStorage, _l_w: &Layout, s_scale: &candle_core::MetalStorage, l_scale: &Layout) -> Result<(candle_core::MetalStorage, Shape)> { + fn metal_fwd( + &self, + s_x: &candle_core::MetalStorage, + l_x: &Layout, + s_w: &candle_core::MetalStorage, + _l_w: &Layout, + s_scale: &candle_core::MetalStorage, + l_scale: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { let device = s_x.device(); let dims = l_x.shape().dims(); let el = l_x.shape().elem_count(); - let hidden = *dims.last().ok_or_else(|| candle_core::Error::Msg("empty shape".into()))?; + let hidden = *dims + .last() + .ok_or_else(|| candle_core::Error::Msg("empty shape".into()))?; let num_rows = el / hidden; - let kernel_name: &'static str = match s_x.dtype() { DType::F32 => "adaln_modulate_f32", DType::F16 => "adaln_modulate_f16", dt => candle_core::bail!("adaln_modulate metal: unsupported dtype {dt:?}") }; + let kernel_name: &'static str = match s_x.dtype() { + DType::F32 => "adaln_modulate_f32", + DType::F16 => "adaln_modulate_f16", + dt => candle_core::bail!("adaln_modulate metal: unsupported dtype {dt:?}"), + }; let pipeline = PIPELINE_CACHE.get_or_create(device, kernel_name)?; let output = device.new_buffer(el, s_x.dtype(), "adaln_modulate")?; let encoder = device.command_encoder()?; @@ -643,29 +1221,68 @@ impl candle_core::CustomOp3 for MetalAdalnModulate { let off_x = l_x.start_offset() * s_x.dtype().size_in_bytes(); let off_w = 0usize; // weight is always from start let off_scale = l_scale.start_offset() * s_scale.dtype().size_in_bytes(); - let off_shift = self.shift_layout.start_offset() * self.shift_storage.dtype().size_in_bytes(); + let off_shift = + self.shift_layout.start_offset() * self.shift_storage.dtype().size_in_bytes(); candle_metal_kernels::utils::set_param(&encoder, 0, (s_x.buffer(), off_x)); candle_metal_kernels::utils::set_param(&encoder, 1, (s_w.buffer(), off_w)); candle_metal_kernels::utils::set_param(&encoder, 2, (s_scale.buffer(), off_scale)); - candle_metal_kernels::utils::set_param(&encoder, 3, (self.shift_storage.buffer(), off_shift)); + candle_metal_kernels::utils::set_param( + &encoder, + 3, + (self.shift_storage.buffer(), off_shift), + ); candle_metal_kernels::utils::set_param(&encoder, 4, (&*output, 0usize)); 
candle_metal_kernels::utils::set_param(&encoder, 5, hidden as u32); candle_metal_kernels::utils::set_param(&encoder, 6, self.eps); let max_threads = pipeline.max_total_threads_per_threadgroup(); let tg_width = hidden.min(max_threads); - let grid = objc2_metal::MTLSize { width: hidden, height: num_rows, depth: 1 }; - let group = objc2_metal::MTLSize { width: tg_width, height: 1, depth: 1 }; + let grid = objc2_metal::MTLSize { + width: hidden, + height: num_rows, + depth: 1, + }; + let group = objc2_metal::MTLSize { + width: tg_width, + height: 1, + depth: 1, + }; encoder.dispatch_threads(grid, group); - Ok((candle_core::MetalStorage::new(output, device.clone(), el, s_x.dtype()), l_x.shape().clone())) + Ok(( + candle_core::MetalStorage::new(output, device.clone(), el, s_x.dtype()), + l_x.shape().clone(), + )) } } -struct MetalFusedVectorAttention { scale: f32, gqa_ratio: u32 } +struct MetalFusedVectorAttention { + scale: f32, + gqa_ratio: u32, +} impl candle_core::CustomOp3 for MetalFusedVectorAttention { - fn name(&self) -> &'static str { "metal_fused_vector_attention" } - fn cpu_fwd(&self, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { candle_core::bail!("MetalFusedVectorAttention: expected Metal device") } + fn name(&self) -> &'static str { + "metal_fused_vector_attention" + } + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalFusedVectorAttention: expected Metal device") + } #[allow(clippy::too_many_arguments)] - fn metal_fwd(&self, s_q: &candle_core::MetalStorage, l_q: &Layout, s_k: &candle_core::MetalStorage, l_k: &Layout, s_v: &candle_core::MetalStorage, l_v: &Layout) -> Result<(candle_core::MetalStorage, Shape)> { + fn metal_fwd( + &self, + s_q: &candle_core::MetalStorage, + l_q: &Layout, + s_k: &candle_core::MetalStorage, + l_k: &Layout, + s_v: &candle_core::MetalStorage, + l_v: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { // Q: (batch*heads, head_dim), K/V: (batch*kv_heads, kv_len, head_dim) let device = s_q.device(); let q_dims = l_q.shape().dims(); @@ -696,13 +1313,215 @@ impl candle_core::CustomOp3 for MetalFusedVectorAttention { // Grid: (head_dim, batch*heads) — one column per head_dim element, one row per head let max_threads = pipeline.max_total_threads_per_threadgroup(); let tg_width = head_dim.min(max_threads); - let grid = objc2_metal::MTLSize { width: head_dim, height: bh, depth: 1 }; - let group = objc2_metal::MTLSize { width: tg_width, height: 1, depth: 1 }; + let grid = objc2_metal::MTLSize { + width: head_dim, + height: bh, + depth: 1, + }; + let group = objc2_metal::MTLSize { + width: tg_width, + height: 1, + depth: 1, + }; encoder.dispatch_threads(grid, group); - Ok((candle_core::MetalStorage::new(output, device.clone(), bh * head_dim, out_dtype), Shape::from(vec![bh, head_dim]))) + Ok(( + candle_core::MetalStorage::new(output, device.clone(), bh * head_dim, out_dtype), + Shape::from(vec![bh, head_dim]), + )) } } +// ─── Fused 4-bit dequant + matmul ─────────────────────────────────── + +/// Fused 4-bit quantized matmul on Metal. +/// +/// Reads packed uint32 weights (8 x 4-bit nibbles each), F16 per-group +/// scales and biases, and F16 activations. Dequantizes on-the-fly per +/// output element and accumulates the dot product in F32, writing F16 +/// output. 
This keeps weights at 0.5 bytes/element on GPU instead of +/// expanding to F16 (2 bytes/element) — a 4x memory reduction. +/// +/// Input tensors (passed via `apply_op3_no_bwd`): +/// s1 = packed: (out_features, packed_cols) U32 +/// s2 = scales: (out_features, num_groups) F16 +/// s3 = biases: (out_features, num_groups) F16 +/// +/// The activation tensor `x` (F16) and dimension parameters are captured +/// in the struct since CustomOp3 only supports 3 tensor inputs. +pub(crate) struct MetalQ4MatmulF16 { + /// Activation tensor: (M, in_features), F16, on Metal device. + pub x_storage: candle_core::MetalStorage, + pub x_layout: Layout, + /// Number of rows in activation matrix (batch dimension). + pub m: u32, + /// Number of input features (columns in x, rows in weight). + pub in_features: u32, + /// Number of output features (rows in packed weight). + pub out_features: u32, + /// Quantization group size (typically 32, 64, or 128). + pub group_size: u32, + /// Number of quantization groups = in_features / group_size. + pub num_groups: u32, +} + +impl candle_core::CustomOp3 for MetalQ4MatmulF16 { + fn name(&self) -> &'static str { + if self.m == 1 { + "metal_q4_matvec_f16" + } else { + "metal_q4_matmul_tiled_f16" + } + } + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("MetalQ4MatmulF16: expected Metal device") + } + #[allow(clippy::too_many_arguments)] + fn metal_fwd( + &self, + s_packed: &candle_core::MetalStorage, + l_packed: &Layout, + s_scales: &candle_core::MetalStorage, + l_scales: &Layout, + s_biases: &candle_core::MetalStorage, + l_biases: &Layout, + ) -> Result<(candle_core::MetalStorage, Shape)> { + let device = s_packed.device(); + let kernel_name = if self.m == 1 { + "q4_matvec_f16" + } else { + "q4_matmul_tiled_f16" + }; + let pipeline = PIPELINE_CACHE.get_or_create(device, kernel_name)?; + let out_el = self.m as usize * self.out_features as usize; + let output = device.new_buffer(out_el, DType::F16, "q4_matmul")?; + let encoder = device.command_encoder()?; + encoder.set_compute_pipeline_state(&pipeline); + + let off_packed = l_packed.start_offset() * s_packed.dtype().size_in_bytes(); + let off_scales = l_scales.start_offset() * s_scales.dtype().size_in_bytes(); + let off_biases = l_biases.start_offset() * s_biases.dtype().size_in_bytes(); + let off_x = self.x_layout.start_offset() * self.x_storage.dtype().size_in_bytes(); + + candle_metal_kernels::utils::set_param(&encoder, 0, (s_packed.buffer(), off_packed)); + candle_metal_kernels::utils::set_param(&encoder, 1, (s_scales.buffer(), off_scales)); + candle_metal_kernels::utils::set_param(&encoder, 2, (s_biases.buffer(), off_biases)); + candle_metal_kernels::utils::set_param(&encoder, 3, (self.x_storage.buffer(), off_x)); + candle_metal_kernels::utils::set_param(&encoder, 4, (&*output, 0usize)); + candle_metal_kernels::utils::set_param(&encoder, 5, self.m); + candle_metal_kernels::utils::set_param(&encoder, 6, self.in_features); + candle_metal_kernels::utils::set_param(&encoder, 7, self.out_features); + candle_metal_kernels::utils::set_param(&encoder, 8, self.group_size); + candle_metal_kernels::utils::set_param(&encoder, 9, self.num_groups); + + let (grid, group) = if self.m == 1 { + let row_groups = (self.out_features as usize).div_ceil(8); + ( + objc2_metal::MTLSize { + width: row_groups * 64, + height: 1, + depth: 1, + }, + objc2_metal::MTLSize { + width: 64, + height: 1, + depth: 1, + }, + ) + } 
else {
+            let tile_m = 8usize;
+            let tile_n = 8usize;
+            (
+                objc2_metal::MTLSize {
+                    width: (self.out_features as usize).div_ceil(tile_n) * tile_n,
+                    height: (self.m as usize).div_ceil(tile_m) * tile_m,
+                    depth: 1,
+                },
+                objc2_metal::MTLSize {
+                    width: tile_n,
+                    height: tile_m,
+                    depth: 1,
+                },
+            )
+        };
+        encoder.dispatch_threads(grid, group);
+        drop(encoder);
+        let storage = candle_core::MetalStorage::new(output, device.clone(), out_el, DType::F16);
+        let shape = Shape::from(vec![self.m as usize, self.out_features as usize]);
+        Ok((storage, shape))
+    }
+}
+
+/// Perform fused 4-bit dequant + matmul on Metal.
+///
+/// Given packed 4-bit weights, per-group scales/biases, and F16 activations,
+/// computes `output = x @ dequant(packed, scales, biases)^T` without ever
+/// materializing the full F16 weight matrix.
+///
+/// Dispatches to:
+/// - `q4_matvec_f16` for `M == 1` (decode hot path)
+/// - `q4_matmul_tiled_f16` for `M > 1` (prefill / batched path)
+///
+/// # Arguments
+/// * `packed` - (out_features, in_features/8) U32 tensor on Metal
+/// * `scales` - (out_features, num_groups) F16 tensor on Metal
+/// * `biases` - (out_features, num_groups) F16 tensor on Metal
+/// * `x` - (M, in_features) F16 tensor on Metal
+/// * `group_size` - quantization group size
+///
+/// # Returns
+/// (M, out_features) F16 tensor on Metal
+#[cfg(feature = "metal")]
+pub fn q4_matmul_f16(
+    packed: &Tensor,
+    scales: &Tensor,
+    biases: &Tensor,
+    x: &Tensor,
+    group_size: usize,
+) -> Result<Tensor> {
+    let x = x.contiguous()?;
+    let x_dims = x.dims();
+    let m = if x_dims.len() == 2 {
+        x_dims[0]
+    } else if x_dims.len() == 1 {
+        1
+    } else {
+        candle_core::bail!("q4_matmul_f16: x must be 1D or 2D, got {:?}", x_dims);
+    };
+    let in_features = *x_dims
+        .last()
+        .ok_or_else(|| candle_core::Error::Msg("q4_matmul_f16: empty x shape".into()))?;
+    let packed_dims = packed.dims();
+    let out_features = packed_dims[0];
+    let num_groups = in_features / group_size;
+
+    // Extract Metal storage for x (captured in the op struct)
+    let (x_stor, x_lay) = x.storage_and_layout();
+    let x_metal = match &*x_stor {
+        candle_core::Storage::Metal(ms) => ms.clone(),
+        _ => candle_core::bail!("q4_matmul_f16: x must be on Metal device"),
+    };
+
+    let op = MetalQ4MatmulF16 {
+        x_storage: x_metal,
+        x_layout: x_lay.clone(),
+        m: m as u32,
+        in_features: in_features as u32,
+        out_features: out_features as u32,
+        group_size: group_size as u32,
+        num_groups: num_groups as u32,
+    };
+
+    packed.apply_op3_no_bwd(scales, biases, &op)
+}
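+
+// Reference dequantization for the packed q4 layout, usable from unit tests.
+// This is an illustrative sketch, not the kernel itself: the authoritative
+// dequantization lives in the `q4_*` MSL kernels in ops.msl. It assumes the
+// MLX affine convention `w = scale * q + bias`, with eight 4-bit values per
+// u32 packed little-endian (lowest nibble = lowest column index); the helper
+// name and the nibble order are assumptions, not part of the public API.
+#[allow(dead_code)]
+fn q4_dequant_reference(
+    packed: &[u32],  // (out_features, in_features / 8), row-major
+    scales: &[f32],  // (out_features, num_groups), row-major
+    biases: &[f32],  // (out_features, num_groups), row-major
+    in_features: usize,
+    group_size: usize,
+    row: usize,
+    col: usize,
+) -> f32 {
+    let packed_cols = in_features / 8;
+    let num_groups = in_features / group_size;
+    // Pull the 4-bit code for (row, col) out of its u32 word.
+    let word = packed[row * packed_cols + col / 8];
+    let q = ((word >> (4 * (col % 8))) & 0xF) as f32;
+    // Per-group affine parameters for this column's quantization group.
+    let group = row * num_groups + col / group_size;
+    scales[group] * q + biases[group]
+}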
&Device { + &self.device + } - fn attention(&self, q: &Tensor, k: &Tensor, v: &Tensor, scale: f32, causal: bool) -> Result { + fn attention( + &self, + q: &Tensor, + k: &Tensor, + v: &Tensor, + scale: f32, + causal: bool, + ) -> Result { let q_dims = q.dims(); // Generation case: seq_len=1 → use fused MSL kernel (causal is trivially satisfied for 1 query) if q_dims.len() == 4 && q_dims[2] == 1 && matches!(q.dtype(), DType::F16 | DType::F32) { @@ -734,15 +1587,30 @@ impl ComputeBackend for MetalBackend { // Flatten Q: (batch, heads, 1, head_dim) → (batch*heads, head_dim) let q_flat = q.contiguous()?.reshape((batch * heads, head_dim))?; // Flatten K/V: (batch, kv_heads, kv_len, head_dim) → (batch*kv_heads, kv_len, head_dim) - let k_flat = k.contiguous()?.reshape((batch * kv_heads, k_dims[2], head_dim))?; - let v_flat = v.contiguous()?.reshape((batch * kv_heads, k_dims[2], head_dim))?; - let out = q_flat.apply_op3_no_bwd(&k_flat, &v_flat, &MetalFusedVectorAttention { scale, gqa_ratio })?; + let k_flat = k + .contiguous()? + .reshape((batch * kv_heads, k_dims[2], head_dim))?; + let v_flat = v + .contiguous()? + .reshape((batch * kv_heads, k_dims[2], head_dim))?; + let out = q_flat.apply_op3_no_bwd( + &k_flat, + &v_flat, + &MetalFusedVectorAttention { scale, gqa_ratio }, + )?; return out.reshape((batch, heads, 1, head_dim)); } - // Promote to F32 if needed (F16 SDPA produces imprecise results on Metal) - let q = q.to_dtype(DType::F32)?; // no-op if already F32 - let k = k.to_dtype(DType::F32)?; - let v = v.to_dtype(DType::F32)?; + // Keep native F16/F32 attention on Metal. Only BF16 promotes because + // Metal SDPA does not support BF16 inputs natively. + let (q, k, v) = if matches!(q.dtype(), DType::BF16) { + ( + q.to_dtype(DType::F32)?, + k.to_dtype(DType::F32)?, + v.to_dtype(DType::F32)?, + ) + } else { + (q.clone(), k.clone(), v.clone()) + }; // Try fused SDPA first, fall back to manual attention if threadgroup memory exceeded match candle_nn::ops::sdpa(&q, &k, &v, None, causal, scale, 1.0) { Ok(result) => Ok(result), @@ -763,10 +1631,20 @@ impl ComputeBackend for MetalBackend { } } - fn sdpa(&self, q: &Tensor, k: &Tensor, v: &Tensor, mask: Option<&Tensor>, causal: bool, scale: f32) -> Result { + fn sdpa( + &self, + q: &Tensor, + k: &Tensor, + v: &Tensor, + mask: Option<&Tensor>, + causal: bool, + scale: f32, + ) -> Result { let q_dims = q.dims(); // Generation case: seq_len=1, no mask → fused MSL kernel (avoids SDPA overhead) - if q_dims.len() == 4 && q_dims[2] == 1 && mask.is_none() + if q_dims.len() == 4 + && q_dims[2] == 1 + && mask.is_none() && matches!(q.dtype(), DType::F16 | DType::F32) { let (batch, heads, _, head_dim) = (q_dims[0], q_dims[1], q_dims[2], q_dims[3]); @@ -774,9 +1652,17 @@ impl ComputeBackend for MetalBackend { let kv_heads = k_dims[1]; let gqa_ratio = (heads / kv_heads) as u32; let q_flat = q.contiguous()?.reshape((batch * heads, head_dim))?; - let k_flat = k.contiguous()?.reshape((batch * kv_heads, k_dims[2], head_dim))?; - let v_flat = v.contiguous()?.reshape((batch * kv_heads, k_dims[2], head_dim))?; - let out = q_flat.apply_op3_no_bwd(&k_flat, &v_flat, &MetalFusedVectorAttention { scale, gqa_ratio })?; + let k_flat = k + .contiguous()? + .reshape((batch * kv_heads, k_dims[2], head_dim))?; + let v_flat = v + .contiguous()? 
+ .reshape((batch * kv_heads, k_dims[2], head_dim))?; + let out = q_flat.apply_op3_no_bwd( + &k_flat, + &v_flat, + &MetalFusedVectorAttention { scale, gqa_ratio }, + )?; return out.reshape((batch, heads, 1, head_dim)); } // Default: candle's SDPA @@ -791,6 +1677,32 @@ impl ComputeBackend for MetalBackend { weight.t()?.contiguous() } + fn preprocess_linear_weights(&self, weights: &[&Tensor]) -> Result { + let approx_bytes = weights + .iter() + .map(|weight| weight.shape().elem_count() * weight.dtype().size_in_bytes()) + .sum::(); + self.maybe_log_large_fusion_memory("metal fused linear load: start", approx_bytes); + + let mut processed = Vec::with_capacity(weights.len()); + for (idx, weight) in weights.iter().enumerate() { + processed.push(self.preprocess_linear_weight(weight)?); + self.maybe_log_large_fusion_memory( + &format!( + "metal fused linear load: preprocessed part {}/{}", + idx + 1, + weights.len() + ), + approx_bytes, + ); + } + + let refs: Vec<&Tensor> = processed.iter().collect(); + let fused = Tensor::cat(&refs, 1)?; + self.maybe_log_large_fusion_memory("metal fused linear load: fused", approx_bytes); + Ok(fused) + } + fn linear_forward(&self, x: &Tensor, weight: &Tensor, bias: Option<&Tensor>) -> Result { // Weight is pre-transposed by preprocess_linear_weight: shape (in_features, out_features). // No t() needed — just matmul directly. @@ -823,6 +1735,31 @@ impl ComputeBackend for MetalBackend { } } + // ── Fused 4-bit quantized matmul ──────────────────────────���────── + + fn q4_linear_forward( + &self, + packed: &Tensor, + scales: &Tensor, + biases: &Tensor, + x: &Tensor, + group_size: usize, + ) -> Result { + // Handle batched inputs: reshape 3D (batch, seq, features) → 2D, dispatch, reshape back. + let x_dims = x.dims(); + match x_dims { + [b, s, _k] => { + let b = *b; + let s = *s; + let x2d = x.reshape((b * s, ()))?; + let out2d = q4_matmul_f16(packed, scales, biases, &x2d, group_size)?; + let out_features = out2d.dim(1)?; + out2d.reshape((b, s, out_features)) + } + _ => q4_matmul_f16(packed, scales, biases, x, group_size), + } + } + // ── MSL-accelerated ops (validated by Metal vs CPU tests) ──────── fn gelu(&self, x: &Tensor) -> Result { @@ -877,16 +1814,41 @@ impl ComputeBackend for MetalBackend { } } - fn depthwise_conv1d_silu(&self, window: &Tensor, weight: &Tensor, _kernel_size: usize, _channels: usize) -> Result { + fn depthwise_conv1d_silu( + &self, + window: &Tensor, + weight: &Tensor, + _kernel_size: usize, + _channels: usize, + ) -> Result { window.apply_op2_no_bwd(weight, &MetalDepthwiseConv1dSilu) } - fn depthwise_conv1d_bias(&self, padded_input: &Tensor, weight: &Tensor, bias: &Tensor, _kernel_size: usize, _channels: usize) -> Result { - let weight_flat = if weight.dims().len() == 3 { weight.contiguous()?.flatten(1, 2)? } else { weight.contiguous()? }; + fn depthwise_conv1d_bias( + &self, + padded_input: &Tensor, + weight: &Tensor, + bias: &Tensor, + _kernel_size: usize, + _channels: usize, + ) -> Result { + let weight_flat = if weight.dims().len() == 3 { + weight.contiguous()?.flatten(1, 2)? + } else { + weight.contiguous()? 
+ }; padded_input.apply_op3_no_bwd(&weight_flat, bias, &MetalDepthwiseConv1dBias) } - fn depthwise_conv1d_bias_ctx(&self, ctx: &Tensor, input: &Tensor, weight: &Tensor, bias: &Tensor, kernel_size: usize, channels: usize) -> Result { + fn depthwise_conv1d_bias_ctx( + &self, + ctx: &Tensor, + input: &Tensor, + weight: &Tensor, + bias: &Tensor, + kernel_size: usize, + channels: usize, + ) -> Result { let merged = Tensor::cat(&[ctx, input], 2)?; self.depthwise_conv1d_bias(&merged, weight, bias, kernel_size, channels) } @@ -894,7 +1856,11 @@ impl ComputeBackend for MetalBackend { // ── MSL-accelerated normalization ────────────────────────────────── fn rope(&self, x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { - if matches!(x.dtype(), DType::F32 | DType::F16) && x.is_contiguous() && cos.is_contiguous() && sin.is_contiguous() { + if matches!(x.dtype(), DType::F32 | DType::F16) + && x.is_contiguous() + && cos.is_contiguous() + && sin.is_contiguous() + { let x_dims = x.dims(); if x_dims.len() == 4 { let (_b, _h, seq_len, head_dim) = (x_dims[0], x_dims[1], x_dims[2], x_dims[3]); @@ -907,8 +1873,16 @@ impl ComputeBackend for MetalBackend { let cos_flat = cos.reshape(((), half_dim))?; let sin_flat = sin.reshape(((), half_dim))?; // Narrow to actual seq_len if cos has more positions - let cos_narrow = if cos_flat.dim(0)? > seq_len { cos_flat.narrow(0, 0, seq_len)? } else { cos_flat }; - let sin_narrow = if sin_flat.dim(0)? > seq_len { sin_flat.narrow(0, 0, seq_len)? } else { sin_flat }; + let cos_narrow = if cos_flat.dim(0)? > seq_len { + cos_flat.narrow(0, 0, seq_len)? + } else { + cos_flat + }; + let sin_narrow = if sin_flat.dim(0)? > seq_len { + sin_flat.narrow(0, 0, seq_len)? + } else { + sin_flat + }; let cos_c = cos_narrow.contiguous()?; let sin_c = sin_narrow.contiguous()?; return x.apply_op3_no_bwd(&cos_c, &sin_c, &MetalRope); @@ -921,32 +1895,61 @@ impl ComputeBackend for MetalBackend { fn rms_norm(&self, x: &Tensor, weight: &Tensor, eps: f32) -> Result { let x = x.contiguous()?; + let hidden = x.dim(candle_core::D::Minus1)?; + let kernel_name = match x.dtype() { + DType::F32 => "rms_norm_f32", + DType::F16 => "rms_norm_f16", + _ => return candle_nn::ops::rms_norm(&x, weight, eps), + }; + if !self.full_row_kernel_supported(kernel_name, hidden)? { + return candle_nn::ops::rms_norm(&x, weight, eps); + } x.apply_op2_no_bwd(weight, &MetalRmsNorm { eps }) } - fn layer_norm(&self, x: &Tensor, weight: &Tensor, bias: Option<&Tensor>, eps: f32) -> Result { + fn layer_norm( + &self, + x: &Tensor, + weight: &Tensor, + bias: Option<&Tensor>, + eps: f32, + ) -> Result { if let Some(b) = bias { if x.is_contiguous() && matches!(x.dtype(), DType::F32 | DType::F16) { let x = x.contiguous()?; let b = b.contiguous()?; - if let Device::Metal(_) = x.device() { + let kernel_name = match x.dtype() { + DType::F32 => "layer_norm_f32", + DType::F16 => "layer_norm_f16", + _ => unreachable!(), + }; + if self.full_row_kernel_supported(kernel_name, x.dim(D::Minus1)?)? 
+ && matches!(x.device(), Device::Metal(_)) + { let (bias_storage, bias_layout) = b.storage_and_layout(); if let candle_core::Storage::Metal(ms) = &*bias_storage { - let op = MetalLayerNorm { eps, bias_storage: ms.clone(), bias_layout: bias_layout.clone() }; + let op = MetalLayerNorm { + eps, + bias_storage: ms.clone(), + bias_layout: bias_layout.clone(), + }; return x.apply_op2_no_bwd(weight, &op); } } } } // Fallback to default implementation - use candle_core::{DType as D2, D}; + use candle_core::{D, DType as D2}; if x.is_contiguous() { if let Some(b) = bias { return candle_nn::ops::layer_norm(x, weight, b, eps); } } let x_dtype = x.dtype(); - let internal_dtype = match x_dtype { D2::F16 | D2::BF16 => D2::F32, d => d }; + let internal_dtype = match x_dtype { + D2::F16 | D2::BF16 => D2::F32, + d => d, + }; let hidden_size = x.dim(D::Minus1)?; let x = x.to_dtype(internal_dtype)?; let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?; @@ -954,18 +1957,55 @@ impl ComputeBackend for MetalBackend { let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?; let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(weight)?; - match bias { Some(b) => x.broadcast_add(b), None => Ok(x) } + match bias { + Some(b) => x.broadcast_add(b), + None => Ok(x), + } } fn rms_norm_gated(&self, x: &Tensor, z: &Tensor, weight: &Tensor, eps: f32) -> Result { let x = x.contiguous()?; let z = z.contiguous()?.to_dtype(x.dtype())?; + let hidden = x.dim(candle_core::D::Minus1)?; + let kernel_name = match x.dtype() { + DType::F32 => "rms_norm_gated_f32", + DType::F16 => "rms_norm_gated_f16", + _ => { + let n = candle_nn::ops::rms_norm(&x, weight, eps)?; + return n * candle_nn::ops::silu(&z)?; + } + }; + if !self.full_row_kernel_supported(kernel_name, hidden)? { + let n = candle_nn::ops::rms_norm(&x, weight, eps)?; + return n * candle_nn::ops::silu(&z)?; + } x.apply_op3_no_bwd(&z, weight, &MetalRmsNormGated { eps }) } - fn add_rms_norm(&self, a: &Tensor, b: &Tensor, weight: &Tensor, eps: f32) -> Result<(Tensor, Tensor)> { + fn add_rms_norm( + &self, + a: &Tensor, + b: &Tensor, + weight: &Tensor, + eps: f32, + ) -> Result<(Tensor, Tensor)> { let a = a.contiguous()?; let b = b.contiguous()?; + let hidden = a.dim(candle_core::D::Minus1)?; + let kernel_name = match a.dtype() { + DType::F32 => "add_rms_norm_f32", + DType::F16 => "add_rms_norm_f16", + _ => { + let res = (&a + &b)?; + let normed = candle_nn::ops::rms_norm(&res, weight, eps)?; + return Ok((res, normed)); + } + }; + if !self.full_row_kernel_supported(kernel_name, hidden)? 
{ + let res = (&a + &b)?; + let normed = candle_nn::ops::rms_norm(&res, weight, eps)?; + return Ok((res, normed)); + } let shape = a.shape().clone(); let el = shape.elem_count(); let packed = a.apply_op3_no_bwd(&b, weight, &MetalAddRmsNorm { eps })?; @@ -979,7 +2019,14 @@ impl ComputeBackend for MetalBackend { x.apply_op2_no_bwd(weight, &MetalRmsNormChannel { eps }) } - fn adaln_modulate(&self, x: &Tensor, norm_weight: &Tensor, scale: &Tensor, shift: &Tensor, eps: f32) -> Result { + fn adaln_modulate( + &self, + x: &Tensor, + norm_weight: &Tensor, + scale: &Tensor, + shift: &Tensor, + eps: f32, + ) -> Result { let x = x.contiguous()?; let scale = scale.contiguous()?; let shift = shift.contiguous()?; @@ -987,7 +2034,11 @@ impl ComputeBackend for MetalBackend { if let Device::Metal(_) = x.device() { let (shift_storage, shift_layout) = shift.storage_and_layout(); if let candle_core::Storage::Metal(ms) = &*shift_storage { - let op = MetalAdalnModulate { eps, shift_storage: ms.clone(), shift_layout: shift_layout.clone() }; + let op = MetalAdalnModulate { + eps, + shift_storage: ms.clone(), + shift_layout: shift_layout.clone(), + }; return x.apply_op3_no_bwd(norm_weight, &scale, &op); } } @@ -997,20 +2048,31 @@ impl ComputeBackend for MetalBackend { } fn f8e4m3_to_f32(&self, x: &Tensor) -> Result { - if x.dtype() != DType::F8E4M3 { return x.to_dtype(DType::F32); } + if x.dtype() != DType::F8E4M3 { + return x.to_dtype(DType::F32); + } x.apply_op1_no_bwd(&MetalF8ToF32) } fn f8e4m3_to_f16(&self, x: &Tensor) -> Result { - if x.dtype() != DType::F8E4M3 { return x.to_dtype(DType::F16); } + if x.dtype() != DType::F8E4M3 { + return x.to_dtype(DType::F16); + } x.apply_op1_no_bwd(&MetalF8ToF16) } fn f8e4m3_to_bf16(&self, x: &Tensor) -> Result { - if x.dtype() != DType::F8E4M3 { return x.to_dtype(DType::BF16); } + if x.dtype() != DType::F8E4M3 { + return x.to_dtype(DType::BF16); + } let dev = x.device().clone(); - x.to_device(&Device::Cpu)?.to_dtype(DType::F32)?.to_dtype(DType::BF16)?.to_device(&dev) + x.to_device(&Device::Cpu)? + .to_dtype(DType::F32)? + .to_dtype(DType::BF16)? + .to_device(&dev) } - fn synchronize(&self) -> Result<()> { self.device.synchronize() } + fn synchronize(&self) -> Result<()> { + self.device.synchronize() + } } diff --git a/cake-core/src/backends/metal/ops.msl b/cake-core/src/backends/metal/ops.msl index 91743ee..2ff3783 100644 --- a/cake-core/src/backends/metal/ops.msl +++ b/cake-core/src/backends/metal/ops.msl @@ -839,3 +839,365 @@ kernel void fused_vector_attention_f32( output[bh * head_dim + d] = acc * (1.0f / sum_exp); } +// ─── q4 fused kernels ────────────────────────────────────────────── +// MLX affine 4-bit layout: +// packed: (out_features, in_features/8) uint32, 8 nibbles/U32, LSB-first +// scales: (out_features, num_groups) half +// biases: (out_features, num_groups) half +// x: (M, in_features) half +// output: (M, out_features) half +// +// Dequant: weight = nibble * scale + bias + +#define Q4_MATVEC_THREADS 64u +#define Q4_MATVEC_ROWS_PER_SIMDGROUP 4u +#define Q4_MATVEC_ROWS_PER_GROUP 8u + +#define Q4_TILE_BM 8u +#define Q4_TILE_BN 8u +#define Q4_TILE_BK 32u +#define Q4_TILE_BK_PAD (Q4_TILE_BK + 8u) + +inline float q4_dequant_f32( + uint packed_val, + uint nibble_idx, + half scale, + half bias +) { + float q = float((packed_val >> (nibble_idx * 4u)) & 0xFu); + return fma(q, float(scale), float(bias)); +} + +// Matrix-vector hot path: M == 1 (autoregressive decode). +// One 64-thread threadgroup handles 8 output rows; each simdgroup reduces 4 rows. 
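+//
+// Worked example of the thread-to-row mapping (illustrative numbers; the host
+// side in mod.rs dispatches ceil(out_features / 8) threadgroups of 64 threads):
+//   out_features = 2048 → 256 threadgroups
+//   threadgroup 5       → rows 40..47 (simdgroup 0 → 40..43, simdgroup 1 → 44..47)
+//   each lane reads 4 consecutive uint32 words starting at packed column lane * 4,
+//   striding by 128; per-lane partial sums are reduced with simd_sum and lane 0
+//   writes the four row outputs.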
+kernel void q4_matvec_f16( + device const uint* packed [[buffer(0)]], + device const half* scales [[buffer(1)]], + device const half* biases [[buffer(2)]], + device const half* x [[buffer(3)]], + device half* output [[buffer(4)]], + constant uint& M [[buffer(5)]], + constant uint& in_features [[buffer(6)]], + constant uint& out_features [[buffer(7)]], + constant uint& group_size [[buffer(8)]], + constant uint& num_groups [[buffer(9)]], + uint tid [[thread_index_in_threadgroup]], + uint lane [[thread_index_in_simdgroup]], + uint simd_id [[simdgroup_index_in_threadgroup]], + uint3 tgid [[threadgroup_position_in_grid]] +) { + if (M == 0u || tgid.x * Q4_MATVEC_ROWS_PER_GROUP >= out_features) return; + + uint packed_cols = in_features / 8u; + uint packed_group_shift = (group_size == 32u) ? 2u : ((group_size == 64u) ? 3u : 4u); + uint full_packed_cols = packed_cols & ~3u; + uint row_base = tgid.x * Q4_MATVEC_ROWS_PER_GROUP + simd_id * Q4_MATVEC_ROWS_PER_SIMDGROUP; + uint row0 = row_base + 0u; + uint row1 = row_base + 1u; + uint row2 = row_base + 2u; + uint row3 = row_base + 3u; + bool row0_active = row0 < out_features; + bool row1_active = row1 < out_features; + bool row2_active = row2 < out_features; + bool row3_active = row3 < out_features; + uint row0_offset = row0 * packed_cols; + uint row1_offset = row1 * packed_cols; + uint row2_offset = row2 * packed_cols; + uint row3_offset = row3 * packed_cols; + uint scale_base0 = row0 * num_groups; + uint scale_base1 = row1 * num_groups; + uint scale_base2 = row2 * num_groups; + uint scale_base3 = row3 * num_groups; + + float acc0 = 0.0f; + float acc1 = 0.0f; + float acc2 = 0.0f; + float acc3 = 0.0f; + + uint cached_group = 0xFFFFFFFFu; + float4 scale_vec0 = float4(0.0f); + float4 scale_vec1 = float4(0.0f); + float4 scale_vec2 = float4(0.0f); + float4 scale_vec3 = float4(0.0f); + float4 bias_vec0 = float4(0.0f); + float4 bias_vec1 = float4(0.0f); + float4 bias_vec2 = float4(0.0f); + float4 bias_vec3 = float4(0.0f); + + for (uint pc_block = lane * 4u; pc_block < full_packed_cols; pc_block += 128u) { + uint group_idx = pc_block >> packed_group_shift; + if (group_idx != cached_group) { + cached_group = group_idx; + if (row0_active) { + scale_vec0 = float4(float(scales[scale_base0 + group_idx])); + bias_vec0 = float4(float(biases[scale_base0 + group_idx])); + } + if (row1_active) { + scale_vec1 = float4(float(scales[scale_base1 + group_idx])); + bias_vec1 = float4(float(biases[scale_base1 + group_idx])); + } + if (row2_active) { + scale_vec2 = float4(float(scales[scale_base2 + group_idx])); + bias_vec2 = float4(float(biases[scale_base2 + group_idx])); + } + if (row3_active) { + scale_vec3 = float4(float(scales[scale_base3 + group_idx])); + bias_vec3 = float4(float(biases[scale_base3 + group_idx])); + } + } + + uint j_base = pc_block * 8u; + float4 x0 = float4(float(x[j_base + 0u]), float(x[j_base + 8u]), float(x[j_base + 16u]), float(x[j_base + 24u])); + float4 x1 = float4(float(x[j_base + 1u]), float(x[j_base + 9u]), float(x[j_base + 17u]), float(x[j_base + 25u])); + float4 x2 = float4(float(x[j_base + 2u]), float(x[j_base + 10u]), float(x[j_base + 18u]), float(x[j_base + 26u])); + float4 x3 = float4(float(x[j_base + 3u]), float(x[j_base + 11u]), float(x[j_base + 19u]), float(x[j_base + 27u])); + float4 x4 = float4(float(x[j_base + 4u]), float(x[j_base + 12u]), float(x[j_base + 20u]), float(x[j_base + 28u])); + float4 x5 = float4(float(x[j_base + 5u]), float(x[j_base + 13u]), float(x[j_base + 21u]), float(x[j_base + 29u])); + float4 x6 = 
float4(float(x[j_base + 6u]), float(x[j_base + 14u]), float(x[j_base + 22u]), float(x[j_base + 30u])); + float4 x7 = float4(float(x[j_base + 7u]), float(x[j_base + 15u]), float(x[j_base + 23u]), float(x[j_base + 31u])); + + if (row0_active) { + uint4 packed_row = uint4( + packed[row0_offset + pc_block + 0u], + packed[row0_offset + pc_block + 1u], + packed[row0_offset + pc_block + 2u], + packed[row0_offset + pc_block + 3u] + ); + acc0 += dot(fma(float4(packed_row & 0xFu), scale_vec0, bias_vec0), x0); + acc0 += dot(fma(float4((packed_row >> 4u) & 0xFu), scale_vec0, bias_vec0), x1); + acc0 += dot(fma(float4((packed_row >> 8u) & 0xFu), scale_vec0, bias_vec0), x2); + acc0 += dot(fma(float4((packed_row >> 12u) & 0xFu), scale_vec0, bias_vec0), x3); + acc0 += dot(fma(float4((packed_row >> 16u) & 0xFu), scale_vec0, bias_vec0), x4); + acc0 += dot(fma(float4((packed_row >> 20u) & 0xFu), scale_vec0, bias_vec0), x5); + acc0 += dot(fma(float4((packed_row >> 24u) & 0xFu), scale_vec0, bias_vec0), x6); + acc0 += dot(fma(float4((packed_row >> 28u) & 0xFu), scale_vec0, bias_vec0), x7); + } + if (row1_active) { + uint4 packed_row = uint4( + packed[row1_offset + pc_block + 0u], + packed[row1_offset + pc_block + 1u], + packed[row1_offset + pc_block + 2u], + packed[row1_offset + pc_block + 3u] + ); + acc1 += dot(fma(float4(packed_row & 0xFu), scale_vec1, bias_vec1), x0); + acc1 += dot(fma(float4((packed_row >> 4u) & 0xFu), scale_vec1, bias_vec1), x1); + acc1 += dot(fma(float4((packed_row >> 8u) & 0xFu), scale_vec1, bias_vec1), x2); + acc1 += dot(fma(float4((packed_row >> 12u) & 0xFu), scale_vec1, bias_vec1), x3); + acc1 += dot(fma(float4((packed_row >> 16u) & 0xFu), scale_vec1, bias_vec1), x4); + acc1 += dot(fma(float4((packed_row >> 20u) & 0xFu), scale_vec1, bias_vec1), x5); + acc1 += dot(fma(float4((packed_row >> 24u) & 0xFu), scale_vec1, bias_vec1), x6); + acc1 += dot(fma(float4((packed_row >> 28u) & 0xFu), scale_vec1, bias_vec1), x7); + } + if (row2_active) { + uint4 packed_row = uint4( + packed[row2_offset + pc_block + 0u], + packed[row2_offset + pc_block + 1u], + packed[row2_offset + pc_block + 2u], + packed[row2_offset + pc_block + 3u] + ); + acc2 += dot(fma(float4(packed_row & 0xFu), scale_vec2, bias_vec2), x0); + acc2 += dot(fma(float4((packed_row >> 4u) & 0xFu), scale_vec2, bias_vec2), x1); + acc2 += dot(fma(float4((packed_row >> 8u) & 0xFu), scale_vec2, bias_vec2), x2); + acc2 += dot(fma(float4((packed_row >> 12u) & 0xFu), scale_vec2, bias_vec2), x3); + acc2 += dot(fma(float4((packed_row >> 16u) & 0xFu), scale_vec2, bias_vec2), x4); + acc2 += dot(fma(float4((packed_row >> 20u) & 0xFu), scale_vec2, bias_vec2), x5); + acc2 += dot(fma(float4((packed_row >> 24u) & 0xFu), scale_vec2, bias_vec2), x6); + acc2 += dot(fma(float4((packed_row >> 28u) & 0xFu), scale_vec2, bias_vec2), x7); + } + if (row3_active) { + uint4 packed_row = uint4( + packed[row3_offset + pc_block + 0u], + packed[row3_offset + pc_block + 1u], + packed[row3_offset + pc_block + 2u], + packed[row3_offset + pc_block + 3u] + ); + acc3 += dot(fma(float4(packed_row & 0xFu), scale_vec3, bias_vec3), x0); + acc3 += dot(fma(float4((packed_row >> 4u) & 0xFu), scale_vec3, bias_vec3), x1); + acc3 += dot(fma(float4((packed_row >> 8u) & 0xFu), scale_vec3, bias_vec3), x2); + acc3 += dot(fma(float4((packed_row >> 12u) & 0xFu), scale_vec3, bias_vec3), x3); + acc3 += dot(fma(float4((packed_row >> 16u) & 0xFu), scale_vec3, bias_vec3), x4); + acc3 += dot(fma(float4((packed_row >> 20u) & 0xFu), scale_vec3, bias_vec3), x5); + acc3 += dot(fma(float4((packed_row 
>> 24u) & 0xFu), scale_vec3, bias_vec3), x6); + acc3 += dot(fma(float4((packed_row >> 28u) & 0xFu), scale_vec3, bias_vec3), x7); + } + } + + for (uint pc = full_packed_cols + lane; pc < packed_cols; pc += 32u) { + uint group_idx = pc >> packed_group_shift; + if (group_idx != cached_group) { + cached_group = group_idx; + if (row0_active) { + scale_vec0 = float4(float(scales[scale_base0 + group_idx])); + bias_vec0 = float4(float(biases[scale_base0 + group_idx])); + } + if (row1_active) { + scale_vec1 = float4(float(scales[scale_base1 + group_idx])); + bias_vec1 = float4(float(biases[scale_base1 + group_idx])); + } + if (row2_active) { + scale_vec2 = float4(float(scales[scale_base2 + group_idx])); + bias_vec2 = float4(float(biases[scale_base2 + group_idx])); + } + if (row3_active) { + scale_vec3 = float4(float(scales[scale_base3 + group_idx])); + bias_vec3 = float4(float(biases[scale_base3 + group_idx])); + } + } + + uint j_base = pc * 8u; + float4 x_lo = float4(float(x[j_base + 0u]), float(x[j_base + 1u]), float(x[j_base + 2u]), float(x[j_base + 3u])); + float4 x_hi = float4(float(x[j_base + 4u]), float(x[j_base + 5u]), float(x[j_base + 6u]), float(x[j_base + 7u])); + + if (row0_active) { + uint packed_val = packed[row0_offset + pc]; + acc0 += dot(fma(float4( + float((packed_val >> 0u) & 0xFu), + float((packed_val >> 4u) & 0xFu), + float((packed_val >> 8u) & 0xFu), + float((packed_val >> 12u) & 0xFu) + ), scale_vec0, bias_vec0), x_lo); + acc0 += dot(fma(float4( + float((packed_val >> 16u) & 0xFu), + float((packed_val >> 20u) & 0xFu), + float((packed_val >> 24u) & 0xFu), + float((packed_val >> 28u) & 0xFu) + ), scale_vec0, bias_vec0), x_hi); + } + if (row1_active) { + uint packed_val = packed[row1_offset + pc]; + acc1 += dot(fma(float4( + float((packed_val >> 0u) & 0xFu), + float((packed_val >> 4u) & 0xFu), + float((packed_val >> 8u) & 0xFu), + float((packed_val >> 12u) & 0xFu) + ), scale_vec1, bias_vec1), x_lo); + acc1 += dot(fma(float4( + float((packed_val >> 16u) & 0xFu), + float((packed_val >> 20u) & 0xFu), + float((packed_val >> 24u) & 0xFu), + float((packed_val >> 28u) & 0xFu) + ), scale_vec1, bias_vec1), x_hi); + } + if (row2_active) { + uint packed_val = packed[row2_offset + pc]; + acc2 += dot(fma(float4( + float((packed_val >> 0u) & 0xFu), + float((packed_val >> 4u) & 0xFu), + float((packed_val >> 8u) & 0xFu), + float((packed_val >> 12u) & 0xFu) + ), scale_vec2, bias_vec2), x_lo); + acc2 += dot(fma(float4( + float((packed_val >> 16u) & 0xFu), + float((packed_val >> 20u) & 0xFu), + float((packed_val >> 24u) & 0xFu), + float((packed_val >> 28u) & 0xFu) + ), scale_vec2, bias_vec2), x_hi); + } + if (row3_active) { + uint packed_val = packed[row3_offset + pc]; + acc3 += dot(fma(float4( + float((packed_val >> 0u) & 0xFu), + float((packed_val >> 4u) & 0xFu), + float((packed_val >> 8u) & 0xFu), + float((packed_val >> 12u) & 0xFu) + ), scale_vec3, bias_vec3), x_lo); + acc3 += dot(fma(float4( + float((packed_val >> 16u) & 0xFu), + float((packed_val >> 20u) & 0xFu), + float((packed_val >> 24u) & 0xFu), + float((packed_val >> 28u) & 0xFu) + ), scale_vec3, bias_vec3), x_hi); + } + } + + acc0 = simd_sum(acc0); + acc1 = simd_sum(acc1); + acc2 = simd_sum(acc2); + acc3 = simd_sum(acc3); + + if (lane == 0u) { + if (row_base + 0u < out_features) output[row_base + 0u] = half(acc0); + if (row_base + 1u < out_features) output[row_base + 1u] = half(acc1); + if (row_base + 2u < out_features) output[row_base + 2u] = half(acc2); + if (row_base + 3u < out_features) output[row_base + 3u] = half(acc3); + } +} + 
+// Tiled matrix-matrix path: M > 1 (prefill / batched prompts). +// Dequantizes one BN x BK tile into threadgroup memory and reuses it across BM rows. +kernel void q4_matmul_tiled_f16( + device const uint* packed [[buffer(0)]], + device const half* scales [[buffer(1)]], + device const half* biases [[buffer(2)]], + device const half* x [[buffer(3)]], + device half* output [[buffer(4)]], + constant uint& M [[buffer(5)]], + constant uint& in_features [[buffer(6)]], + constant uint& out_features [[buffer(7)]], + constant uint& group_size [[buffer(8)]], + constant uint& num_groups [[buffer(9)]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 tgid [[threadgroup_position_in_grid]] +) { + threadgroup half x_tile[Q4_TILE_BM][Q4_TILE_BK]; + threadgroup half w_tile[Q4_TILE_BN][Q4_TILE_BK_PAD]; + + uint local_col = lid.x; + uint local_row = lid.y; + uint linear_tid = local_row * Q4_TILE_BN + local_col; + + uint row = tgid.y * Q4_TILE_BM + local_row; + uint col = tgid.x * Q4_TILE_BN + local_col; + uint packed_cols = in_features / 8u; + + float acc = 0.0f; + + for (uint k0 = 0u; k0 < in_features; k0 += Q4_TILE_BK) { + for (uint idx = linear_tid; idx < Q4_TILE_BM * Q4_TILE_BK; idx += Q4_TILE_BM * Q4_TILE_BN) { + uint tile_r = idx / Q4_TILE_BK; + uint tile_k = idx % Q4_TILE_BK; + uint global_row = tgid.y * Q4_TILE_BM + tile_r; + uint global_k = k0 + tile_k; + x_tile[tile_r][tile_k] = + (global_row < M && global_k < in_features) + ? x[global_row * in_features + global_k] + : half(0.0h); + } + + for (uint idx = linear_tid; idx < Q4_TILE_BN * Q4_TILE_BK; idx += Q4_TILE_BM * Q4_TILE_BN) { + uint tile_c = idx / Q4_TILE_BK; + uint tile_k = idx % Q4_TILE_BK; + uint global_col = tgid.x * Q4_TILE_BN + tile_c; + uint global_k = k0 + tile_k; + + if (global_col < out_features && global_k < in_features) { + uint pc = global_k / 8u; + uint nib = global_k & 7u; + uint group = global_k / group_size; + uint packed_val = packed[global_col * packed_cols + pc]; + w_tile[tile_c][tile_k] = half(q4_dequant_f32( + packed_val, + nib, + scales[global_col * num_groups + group], + biases[global_col * num_groups + group] + )); + } else { + w_tile[tile_c][tile_k] = half(0.0h); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (row < M && col < out_features) { + for (uint kk = 0u; kk < Q4_TILE_BK; kk++) { + acc = fma(float(x_tile[local_row][kk]), float(w_tile[local_col][kk]), acc); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + if (row < M && col < out_features) { + output[row * out_features + col] = half(acc); + } +} diff --git a/cake-core/src/backends/mod.rs b/cake-core/src/backends/mod.rs index 64adf71..b0f3146 100644 --- a/cake-core/src/backends/mod.rs +++ b/cake-core/src/backends/mod.rs @@ -22,6 +22,8 @@ pub use cuda::CudaBackend; mod metal; #[cfg(feature = "metal")] pub use self::metal::MetalBackend; +#[cfg(feature = "metal")] +pub use self::metal::q4_matmul_f16; #[cfg(feature = "vulkan")] mod vulkan; @@ -87,13 +89,7 @@ pub trait ComputeBackend: Send + Sync + std::fmt::Debug { // ── Fused normalization ────────────────────────────────────────── /// `rms_norm(x, weight, eps) * silu(z)` — GDN output gating. - fn rms_norm_gated( - &self, - x: &Tensor, - z: &Tensor, - weight: &Tensor, - eps: f32, - ) -> Result; + fn rms_norm_gated(&self, x: &Tensor, z: &Tensor, weight: &Tensor, eps: f32) -> Result; /// `rms_norm(a + b, weight, eps)` — residual + norm fusion. /// Returns `(residual, normed)` where `residual = a + b`. 
@@ -194,6 +190,20 @@ pub trait ComputeBackend: Send + Sync + std::fmt::Debug { Ok(weight.clone()) } + /// Pre-process and fuse multiple linear weights. + /// + /// Default behavior preserves the existing semantics: concatenate the + /// original `(out_features, in_features)` weights first, then run the + /// backend-specific preprocessing step once on the fused tensor. + /// + /// Backends that transpose or otherwise expand weights during + /// `preprocess_linear_weight` can override this to reduce peak memory by + /// preprocessing each part incrementally before concatenation. + fn preprocess_linear_weights(&self, weights: &[&Tensor]) -> Result { + let fused = Tensor::cat(weights, 0)?; + self.preprocess_linear_weight(&fused) + } + // ── Inference primitives ────────────────────────────────────────── /// Linear layer forward: `x @ weight^T + bias`. @@ -203,12 +213,7 @@ pub trait ComputeBackend: Send + Sync + std::fmt::Debug { /// (avoids slow broadcast_matmul on CUDA/CPU) /// - For non-contiguous 3D+: uses broadcast_left on weight /// - No dtype conversion (caller is responsible) - fn linear_forward( - &self, - x: &Tensor, - weight: &Tensor, - bias: Option<&Tensor>, - ) -> Result { + fn linear_forward(&self, x: &Tensor, weight: &Tensor, bias: Option<&Tensor>) -> Result { let out = match x.dims() { [b1, b2, m, k] => { if x.is_contiguous() { @@ -240,6 +245,27 @@ pub trait ComputeBackend: Send + Sync + std::fmt::Debug { } } + /// Fused 4-bit dequant + matmul: `output = x @ dequant(packed, scales, biases)^T`. + /// + /// On Metal, this dispatches to the q4_matmul_f16 MSL kernel, keeping weights + /// at 0.5 bytes/element (4x memory reduction vs F16). + /// + /// Default: dequantizes on CPU via `gptq::dequantize_packed_4bit` and calls + /// `linear_forward`. Suboptimal but correct — only Metal overrides this. + fn q4_linear_forward( + &self, + packed: &Tensor, + scales: &Tensor, + biases: &Tensor, + x: &Tensor, + group_size: usize, + ) -> Result { + // CPU fallback: dequantize to F32, convert to input dtype, matmul. + let weight = crate::utils::gptq::dequantize_packed_4bit(packed, scales, biases, group_size)?; + let weight = weight.to_dtype(x.dtype())?.to_device(x.device())?; + self.linear_forward(x, &weight, None) + } + /// RMS normalization: `x * weight / sqrt(mean(x^2) + eps)`. fn rms_norm(&self, x: &Tensor, weight: &Tensor, eps: f32) -> Result { candle_nn::ops::rms_norm(x, weight, eps) @@ -285,9 +311,8 @@ pub trait ComputeBackend: Send + Sync + std::fmt::Debug { match &b_data { Some(bd) => { for i in 0..hidden { - out[off + i] = - (((row[i] as f64 - mean) * rstd) * w_data[i] as f64 - + bd[i] as f64) as f32; + out[off + i] = (((row[i] as f64 - mean) * rstd) * w_data[i] as f64 + + bd[i] as f64) as f32; } } None => { @@ -530,12 +555,7 @@ pub trait ComputeBackend: Send + Sync + std::fmt::Debug { /// Create a causal attention mask. Returns a U8 tensor of shape `(seq_len, kv_len)` /// where 1 = masked (future position), 0 = attend. /// Callers use `masked_fill` or `where_cond` to apply the mask. 
- fn causal_mask( - &self, - seq_len: usize, - kv_len: usize, - device: &Device, - ) -> Result { + fn causal_mask(&self, seq_len: usize, kv_len: usize, device: &Device) -> Result { if seq_len == 1 { return Tensor::zeros((1, kv_len), DType::U8, device); } diff --git a/cake-core/src/models/common/attention.rs b/cake-core/src/models/common/attention.rs index b65e95c..e54e79b 100644 --- a/cake-core/src/models/common/attention.rs +++ b/cake-core/src/models/common/attention.rs @@ -4,14 +4,16 @@ use std::sync::Arc; use candle_core::{DType, Result, Tensor, D}; use candle_nn::VarBuilder; -use crate::backends::ComputeBackend; use super::config::load_rms_norm_weight; +use super::mlp::try_load_quantized; +use crate::backends::ComputeBackend; +use crate::utils::quantized_linear::LinearWeight; #[derive(Debug, Clone)] pub struct CausalSelfAttention { - qkv_proj_weight: Tensor, + qkv_proj_weight: LinearWeight, qkv_proj_bias: Option, - o_proj_weight: Tensor, + o_proj_weight: LinearWeight, num_attention_heads: usize, num_key_value_heads: usize, head_dim: usize, @@ -61,14 +63,20 @@ impl CausalSelfAttention { } else { // Partial RoPE: apply only to first rotary_dim channels, pass the rest through. let x_rot = x.narrow(D::Minus1, 0, self.rotary_dim)?.contiguous()?; - let x_pass = x.narrow(D::Minus1, self.rotary_dim, self.head_dim - self.rotary_dim)?.contiguous()?; + let x_pass = x + .narrow(D::Minus1, self.rotary_dim, self.head_dim - self.rotary_dim)? + .contiguous()?; let x_rot = self.backend.rope(&x_rot, &cos, &sin)?; Tensor::cat(&[&x_rot, &x_pass], D::Minus1) } } /// Standard load — derives all flags from `cfg`. - pub fn load(vb: VarBuilder, cfg: &super::Config, backend: Arc) -> Result { + pub fn load( + vb: VarBuilder, + cfg: &super::Config, + backend: Arc, + ) -> Result { Self::load_custom(vb, cfg, cfg.use_qk_norm, cfg.sliding_window, true, backend) } @@ -82,22 +90,42 @@ impl CausalSelfAttention { backend: Arc, ) -> Result { let size_in = cfg.hidden_size; - let head_dim = cfg.head_dim.unwrap_or(cfg.hidden_size / cfg.num_attention_heads); + let head_dim = cfg + .head_dim + .unwrap_or(cfg.hidden_size / cfg.num_attention_heads); let rotary_dim = (head_dim as f32 * cfg.partial_rotary_factor) as usize; let size_q = head_dim * cfg.num_attention_heads; let size_kv = head_dim * cfg.num_key_value_heads; let (qkv_proj_weight, qkv_proj_bias) = if cfg.fused_qkv_proj { // Phi-3/4 style: weights already fused as a single 'qkv_proj' tensor. - let w = vb.pp("qkv_proj").get((size_q + 2 * size_kv, size_in), "weight")?; - let w = backend.preprocess_linear_weight(&w)?; + let vb_proj = vb.pp("qkv_proj"); + let w = if let Some(qw) = try_load_quantized(&vb_proj)? { + qw + } else { + let w = vb_proj.get((size_q + 2 * size_kv, size_in), "weight")?; + LinearWeight::Dense(backend.preprocess_linear_weight(&w)?) + }; (w, None) } else if cfg.use_qkv_bias { - let q_w = vb.pp("q_proj").get((size_q, size_in), "weight")?; - let k_w = vb.pp("k_proj").get((size_kv, size_in), "weight")?; - let v_w = vb.pp("v_proj").get((size_kv, size_in), "weight")?; - let fused_w = Tensor::cat(&[&q_w, &k_w, &v_w], 0)?; - let fused_w = backend.preprocess_linear_weight(&fused_w)?; + // Try quantized path for Q/K/V projections + let q_q = try_load_quantized(&vb.pp("q_proj"))?; + let k_q = try_load_quantized(&vb.pp("k_proj"))?; + let v_q = try_load_quantized(&vb.pp("v_proj"))?; + + let fused_w = match (q_q, k_q, v_q) { + (Some(q), Some(k), Some(v)) => { + super::mlp::fuse_quantized_pub(&[q, k, v])? 
+ } + _ => { + let q_w = vb.pp("q_proj").get((size_q, size_in), "weight")?; + let k_w = vb.pp("k_proj").get((size_kv, size_in), "weight")?; + let v_w = vb.pp("v_proj").get((size_kv, size_in), "weight")?; + LinearWeight::Dense( + backend.preprocess_linear_weights(&[&q_w, &k_w, &v_w])?, + ) + } + }; let q_b = vb.pp("q_proj").get(size_q, "bias")?; let k_b = vb.pp("k_proj").get(size_kv, "bias")?; @@ -106,20 +134,47 @@ impl CausalSelfAttention { (fused_w, Some(fused_b)) } else { - let q_w = vb.pp("q_proj").get((size_q, size_in), "weight")?; - let k_w = vb.pp("k_proj").get((size_kv, size_in), "weight")?; - let v_w = vb.pp("v_proj").get((size_kv, size_in), "weight")?; - let fused_w = Tensor::cat(&[&q_w, &k_w, &v_w], 0)?; - let fused_w = backend.preprocess_linear_weight(&fused_w)?; + let q_q = try_load_quantized(&vb.pp("q_proj"))?; + let k_q = try_load_quantized(&vb.pp("k_proj"))?; + let v_q = try_load_quantized(&vb.pp("v_proj"))?; + + let fused_w = match (q_q, k_q, v_q) { + (Some(q), Some(k), Some(v)) => { + super::mlp::fuse_quantized_pub(&[q, k, v])? + } + _ => { + let q_w = vb.pp("q_proj").get((size_q, size_in), "weight")?; + let k_w = vb.pp("k_proj").get((size_kv, size_in), "weight")?; + let v_w = vb.pp("v_proj").get((size_kv, size_in), "weight")?; + LinearWeight::Dense( + backend.preprocess_linear_weights(&[&q_w, &k_w, &v_w])?, + ) + } + }; (fused_w, None) }; - let o_w = vb.pp("o_proj").get((size_in, size_q), "weight")?; - let o_proj_weight = backend.preprocess_linear_weight(&o_w)?; + let o_proj_weight = { + let vb_o = vb.pp("o_proj"); + if let Some(qw) = try_load_quantized(&vb_o)? { + qw + } else { + let w = vb_o.get((size_in, size_q), "weight")?; + LinearWeight::Dense(backend.preprocess_linear_weight(&w)?) + } + }; let (q_norm_weight, k_norm_weight) = if use_qk_norm { - let norm_dim = if cfg.pre_reshape_qk_norm { size_q } else { head_dim }; - let norm_kv_dim = if cfg.pre_reshape_qk_norm { size_kv } else { head_dim }; + let norm_dim = if cfg.pre_reshape_qk_norm { + size_q + } else { + head_dim + }; + let norm_kv_dim = if cfg.pre_reshape_qk_norm { + size_kv + } else { + head_dim + }; let residual = cfg.residual_rms_norm; let qn = load_rms_norm_weight(norm_dim, residual, vb.pp("q_norm"))?; let kn = load_rms_norm_weight(norm_kv_dim, residual, vb.pp("k_norm"))?; @@ -160,7 +215,8 @@ impl CausalSelfAttention { // Single fused QKV projection (routed through backend for GPU acceleration) let qkv = self - .backend.linear_forward(x, &self.qkv_proj_weight, self.qkv_proj_bias.as_ref()) + .qkv_proj_weight + .forward(x, self.qkv_proj_bias.as_ref(), &*self.backend) .map_err(|e| anyhow!("qkv.forward -> {e}"))?; let q = qkv @@ -176,43 +232,70 @@ impl CausalSelfAttention { // OLMo2-style: apply QK-norm BEFORE head reshape (norm dim = size_q/size_kv). let (q, k) = if self.pre_reshape_qk_norm { let q = if let Some(w) = &self.q_norm_weight { - self.backend.rms_norm(&q.contiguous() - .map_err(|e| anyhow!("pre_reshape q contiguous -> {e}"))?, w, self.qk_norm_eps) + self.backend + .rms_norm( + &q.contiguous() + .map_err(|e| anyhow!("pre_reshape q contiguous -> {e}"))?, + w, + self.qk_norm_eps, + ) .map_err(|e| anyhow!("pre_reshape q_norm -> {e}"))? 
- } else { q }; + } else { + q + }; let k = if let Some(w) = &self.k_norm_weight { - self.backend.rms_norm(&k.contiguous() - .map_err(|e| anyhow!("pre_reshape k contiguous -> {e}"))?, w, self.qk_norm_eps) + self.backend + .rms_norm( + &k.contiguous() + .map_err(|e| anyhow!("pre_reshape k contiguous -> {e}"))?, + w, + self.qk_norm_eps, + ) .map_err(|e| anyhow!("pre_reshape k_norm -> {e}"))? - } else { k }; + } else { + k + }; (q, k) } else { (q, k) }; // Reshape: (b, seq, heads, head_dim) - let q = q - .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?; - let k = k - .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?; - let v = v - .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?; + let q = q.reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?; + let k = k.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?; + let v = v.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?; // Standard QK-norm: applied after reshape (on head_dim, last dim) before transpose. let q = if !self.pre_reshape_qk_norm { if let Some(w) = &self.q_norm_weight { - self.backend.rms_norm(&q.contiguous() - .map_err(|e| anyhow!("q contiguous -> {e}"))?, w, self.qk_norm_eps) + self.backend + .rms_norm( + &q.contiguous().map_err(|e| anyhow!("q contiguous -> {e}"))?, + w, + self.qk_norm_eps, + ) .map_err(|e| anyhow!("q_norm -> {e}"))? - } else { q } - } else { q }; + } else { + q + } + } else { + q + }; let k = if !self.pre_reshape_qk_norm { if let Some(w) = &self.k_norm_weight { - self.backend.rms_norm(&k.contiguous() - .map_err(|e| anyhow!("k contiguous -> {e}"))?, w, self.qk_norm_eps) + self.backend + .rms_norm( + &k.contiguous().map_err(|e| anyhow!("k contiguous -> {e}"))?, + w, + self.qk_norm_eps, + ) .map_err(|e| anyhow!("k_norm -> {e}"))? - } else { k } - } else { k }; + } else { + k + } + } else { + k + }; // Transpose to (b, heads, seq, head_dim). // For generation (seq_len=1), squeeze+unsqueeze avoids the contiguous @@ -220,18 +303,23 @@ impl CausalSelfAttention { // when the swapped dimension has size 1. let (q, k, v) = if seq_len == 1 { ( - q.squeeze(1)?.unsqueeze(2) + q.squeeze(1)? + .unsqueeze(2) .map_err(|e| anyhow!("q.squeeze/unsqueeze -> {e}"))?, - k.squeeze(1)?.unsqueeze(2) + k.squeeze(1)? + .unsqueeze(2) .map_err(|e| anyhow!("k.squeeze/unsqueeze -> {e}"))?, - v.squeeze(1)?.unsqueeze(2) + v.squeeze(1)? + .unsqueeze(2) .map_err(|e| anyhow!("v.squeeze/unsqueeze -> {e}"))?, ) } else { ( - q.transpose(1, 2)?.contiguous() + q.transpose(1, 2)? + .contiguous() .map_err(|e| anyhow!("q.transpose -> {e}"))?, - k.transpose(1, 2)?.contiguous() + k.transpose(1, 2)? + .contiguous() .map_err(|e| anyhow!("k.transpose -> {e}"))?, v.transpose(1, 2) .map_err(|e| anyhow!("v.transpose -> {e}"))?, @@ -273,34 +361,51 @@ impl CausalSelfAttention { { let scale = 1.0 / (self.head_dim as f32).sqrt(); break 'attn crate::utils::flash_attn::flash_attention( - &q, &k, &v, scale, seq_len > 1, - ).map_err(|e| anyhow!("flash_attn: {e}"))?; + &q, + &k, + &v, + scale, + seq_len > 1, + ) + .map_err(|e| anyhow!("flash_attn: {e}"))?; } // The actual kv seq_len (may differ from query seq_len with sliding window) let kv_seq_len = k.dims()[2]; - // Metal: mixed-precision attention (F16 matmuls + F32 softmax) - // Try SDPA first, fall back to manual if threadgroup memory exceeded + // Metal: keep native F16/F32 attention when possible to avoid doubling + // memory bandwidth on the fallback path. 
BF16 still promotes because + // Metal SDPA does not support it natively. #[cfg(feature = "metal")] if matches!(q.device(), candle_core::Device::Metal(_)) { let scale = 1.0 / (self.head_dim as f32).sqrt(); - // Try F32 SDPA (fastest when it works) - let q32 = q.to_dtype(DType::F32)?; - let k32 = k.to_dtype(DType::F32)?; - let v32 = v.to_dtype(DType::F32)?; - match self.backend.sdpa(&q32, &k32, &v32, None, seq_len > 1, scale) { - Ok(result) => break 'attn result, - Err(_) => { - // Fallback: mixed-precision manual attention + if matches!(q.dtype(), DType::BF16) { + let q32 = q.to_dtype(DType::F32)?; + let k32 = k.to_dtype(DType::F32)?; + let v32 = v.to_dtype(DType::F32)?; + if let Ok(result) = + self.backend + .sdpa(&q32, &k32, &v32, None, seq_len > 1, scale) + { + break 'attn result; } + } else if let Ok(result) = self.backend.sdpa(&q, &k, &v, None, seq_len > 1, scale) { + break 'attn result; } } - // Compute attention in F32 for numerical stability (CPU, Metal fallback) - let q = q.to_dtype(DType::F32)?; - let k = k.to_dtype(DType::F32)?; - let v = v.to_dtype(DType::F32)?; + // Compute fallback attention in native F16/F32 on Metal and in F32 elsewhere. + let (q, k, v) = if matches!(q.device(), candle_core::Device::Metal(_)) + && matches!(in_dtype, DType::F16 | DType::F32) + { + (q, k, v) + } else { + ( + q.to_dtype(DType::F32)?, + k.to_dtype(DType::F32)?, + v.to_dtype(DType::F32)?, + ) + }; // Manual attention with GQA head expansion (CPU fallback) let k = self @@ -351,7 +456,7 @@ impl CausalSelfAttention { } else { y.transpose(1, 2)?.reshape(&[b_sz, seq_len, self.size_q])? }; - let y = self.backend.linear_forward(&y, &self.o_proj_weight, None)?; + let y = self.o_proj_weight.forward(&y, None, &*self.backend)?; Ok(y) } diff --git a/cake-core/src/models/common/mlp.rs b/cake-core/src/models/common/mlp.rs index b7fe621..4082c61 100644 --- a/cake-core/src/models/common/mlp.rs +++ b/cake-core/src/models/common/mlp.rs @@ -4,13 +4,39 @@ use candle_core::{Result, Tensor, D}; use candle_nn::VarBuilder; use crate::backends::ComputeBackend; +use crate::utils::quantized_linear::LinearWeight; + +/// Attempt to load a quantized linear weight from a VarBuilder. +/// +/// Returns `Some(LinearWeight::Quantized)` if the prefix has `.scales` and +/// the `.weight` tensor is U32 (packed 4-bit from MetalMlxBackend). +/// Returns `None` if the tensor is not quantized (caller falls back to dense). +pub(crate) fn try_load_quantized(vb: &VarBuilder) -> Result> { + if !vb.contains_tensor("scales") { + return Ok(None); + } + // Try loading the weight — if it's U32, it's packed 4-bit. + if let Ok(weight) = vb.get_unchecked("weight") { + if weight.dtype() == candle_core::DType::U32 { + let scales = vb.get_unchecked("scales")?; + let biases = vb.get_unchecked("biases")?; + let packed_cols = weight.dim(1)?; + let num_quant_groups = scales.dim(1)?; + let group_size = (packed_cols * 8) / num_quant_groups; + return Ok(Some(LinearWeight::quantized( + weight, scales, biases, group_size, + ))); + } + } + Ok(None) +} /// Multi-perceptron implementation with fused gate+up projection. #[allow(clippy::upper_case_acronyms)] #[derive(Debug, Clone)] pub struct MLP { - gate_up_proj_weight: Tensor, - down_proj_weight: Tensor, + gate_up_proj_weight: LinearWeight, + down_proj_weight: LinearWeight, intermediate_size: usize, use_gelu: bool, backend: Arc, @@ -19,35 +45,66 @@ pub struct MLP { impl MLP { /// Execute MLP(x). 
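+    ///
+    /// Computes `down_proj(act(gate) * up)`, where `gate` and `up` are the two
+    /// halves of the single fused `gate_up_proj` matmul and `act` is SiLU, or
+    /// GELU when `use_gelu` is set.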
pub fn forward(&self, x: &Tensor) -> Result { - let fused = self.backend.linear_forward(x, &self.gate_up_proj_weight, None)?; + let fused = self.gate_up_proj_weight.forward(x, None, &*self.backend)?; let gate = fused.narrow(D::Minus1, 0, self.intermediate_size)?; let up = fused.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?; let x = if self.use_gelu { (self.backend.gelu(&gate)? * up)? } else { - self.backend.silu_mul(&gate.contiguous()?, &up.contiguous()?)? + self.backend + .silu_mul(&gate.contiguous()?, &up.contiguous()?)? }; - self.backend.linear_forward(&x, &self.down_proj_weight, None) + self.down_proj_weight.forward(&x, None, &*self.backend) } /// Load this block from the VarBuilder given the specific configuration. - pub fn load(vb: VarBuilder, cfg: &super::Config, backend: Arc) -> Result { + pub fn load( + vb: VarBuilder, + cfg: &super::Config, + backend: Arc, + ) -> Result { let h_size = cfg.hidden_size; let i_size = cfg.intermediate_size; - let gate_up_w = if cfg.fused_gate_up_proj { + let gate_up_proj_weight = if cfg.fused_gate_up_proj { // Phi-3/4 style: weights already fused as 'gate_up_proj' - vb.pp("gate_up_proj").get((2 * i_size, h_size), "weight")? + let vb_proj = vb.pp("gate_up_proj"); + if let Some(qw) = try_load_quantized(&vb_proj)? { + qw + } else { + LinearWeight::Dense(backend.preprocess_linear_weight( + &vb_proj.get((2 * i_size, h_size), "weight")?, + )?) + } } else { // Standard: fuse gate_proj and up_proj into a single matmul - let gate_w = vb.pp("gate_proj").get((i_size, h_size), "weight")?; - let up_w = vb.pp("up_proj").get((i_size, h_size), "weight")?; - Tensor::cat(&[&gate_w, &up_w], 0)? + let gate_q = try_load_quantized(&vb.pp("gate_proj"))?; + let up_q = try_load_quantized(&vb.pp("up_proj"))?; + match (gate_q, up_q) { + (Some(g), Some(u)) => { + // Both quantized: fuse packed weights along dim 0 (out_features) + fuse_quantized(&[g, u])? + } + _ => { + // Dense: fuse raw weights then preprocess + let gate_w = vb.pp("gate_proj").get((i_size, h_size), "weight")?; + let up_w = vb.pp("up_proj").get((i_size, h_size), "weight")?; + LinearWeight::Dense( + backend.preprocess_linear_weights(&[&gate_w, &up_w])?, + ) + } + } }; - let gate_up_proj_weight = backend.preprocess_linear_weight(&gate_up_w)?; - let down_w = vb.pp("down_proj").get((h_size, i_size), "weight")?; - let down_proj_weight = backend.preprocess_linear_weight(&down_w)?; + let down_proj_weight = { + let vb_down = vb.pp("down_proj"); + if let Some(qw) = try_load_quantized(&vb_down)? { + qw + } else { + let w = vb_down.get((h_size, i_size), "weight")?; + LinearWeight::Dense(backend.preprocess_linear_weight(&w)?) + } + }; Ok(Self { gate_up_proj_weight, @@ -58,3 +115,31 @@ impl MLP { }) } } + +/// Fuse multiple quantized linear weights by concatenating along dim 0. 
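+///
+/// Shape sketch with hypothetical sizes (`in_features = 4096`, `group_size = 64`):
+/// fusing gate and up parts of shape `(11008, 512)` packed / `(11008, 64)` scales
+/// yields `(22016, 512)` packed and `(22016, 64)` scales and biases. All parts are
+/// assumed to share `in_features` and `group_size`; the fused weight takes the
+/// first part's `group_size`.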
+pub(crate) fn fuse_quantized_pub(weights: &[LinearWeight]) -> Result { + fuse_quantized(weights) +} + +fn fuse_quantized(weights: &[LinearWeight]) -> Result { + use crate::utils::quantized_linear::QuantizedWeight; + + let qws: Vec<&QuantizedWeight> = weights + .iter() + .map(|w| match w { + LinearWeight::Quantized(qw) => Ok(qw), + LinearWeight::Dense(_) => candle_core::bail!("fuse_quantized: expected Quantized"), + }) + .collect::>()?; + + let packed_refs: Vec<&Tensor> = qws.iter().map(|qw| &qw.packed).collect(); + let scales_refs: Vec<&Tensor> = qws.iter().map(|qw| &qw.scales).collect(); + let biases_refs: Vec<&Tensor> = qws.iter().map(|qw| &qw.biases).collect(); + + Ok(LinearWeight::quantized( + Tensor::cat(&packed_refs, 0)?, + Tensor::cat(&scales_refs, 0)?, + Tensor::cat(&biases_refs, 0)?, + qws[0].group_size, + )) +} diff --git a/cake-core/src/utils/gptq.rs b/cake-core/src/utils/gptq.rs index fbc78c3..020c3fd 100644 --- a/cake-core/src/utils/gptq.rs +++ b/cake-core/src/utils/gptq.rs @@ -24,8 +24,8 @@ use std::path::Path; -use candle_core::{safetensors::MmapedSafetensors, DType, Device, Shape, Tensor}; -use candle_nn::{var_builder::SimpleBackend, Init, VarBuilder}; +use candle_core::{DType, Device, Shape, Tensor, safetensors::MmapedSafetensors}; +use candle_nn::{Init, VarBuilder, var_builder::SimpleBackend}; /// Check whether a model uses 4-bit quantization by inspecting its config.json. /// Detects both standard GPTQ (`quant_method: "gptq"`) and affine 4-bit @@ -41,7 +41,8 @@ pub fn is_gptq_quantized(config_path: &Path) -> bool { for root in [&json, json.get("text_config").unwrap_or(&json)] { if let Some(qc) = root.get("quantization_config") { // Standard GPTQ: quant_method == "gptq" - let is_gptq = qc.get("quant_method") + let is_gptq = qc + .get("quant_method") .and_then(|qm| qm.as_str()) .map(|s| s == "gptq") .unwrap_or(false); @@ -49,11 +50,13 @@ pub fn is_gptq_quantized(config_path: &Path) -> bool { return true; } // Affine 4-bit: mode == "affine" && bits == 4 - let is_affine_4bit = qc.get("mode") + let is_affine_4bit = qc + .get("mode") .and_then(|m| m.as_str()) .map(|s| s == "affine") .unwrap_or(false) - && qc.get("bits") + && qc + .get("bits") .and_then(|b| b.as_u64()) .map(|b| b == 4) .unwrap_or(false); @@ -173,8 +176,14 @@ pub fn dequantize_packed_4bit( // Extract raw data — avoid Tensor intermediates for the hot path let pw: Vec = packed.flatten_all()?.to_vec1::()?; - let sc: Vec = scales.to_dtype(DType::F32)?.flatten_all()?.to_vec1::()?; - let bi: Vec = biases.to_dtype(DType::F32)?.flatten_all()?.to_vec1::()?; + let sc: Vec = scales + .to_dtype(DType::F32)? + .flatten_all()? + .to_vec1::()?; + let bi: Vec = biases + .to_dtype(DType::F32)? + .flatten_all()? + .to_vec1::()?; use rayon::prelude::*; let mut weight = vec![0f32; rows * cols]; @@ -198,6 +207,69 @@ pub fn dequantize_packed_4bit( Tensor::from_vec(weight, (rows, cols), &Device::Cpu) } +/// Dequantize directly to F16, skipping the F32 intermediate. +/// +/// This saves 50% peak memory vs `dequantize_packed_4bit` + `to_dtype(F16)`: +/// - Old path: 259 MiB F32 + 129.5 MiB F16 = 388.5 MiB peak per tensor +/// - New path: 129.5 MiB F16 only = 129.5 MiB peak per tensor +/// +/// The arithmetic (w4 * scale + bias) is done in F32 then truncated to F16 +/// per element, which matches the precision of the old path. 
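+///
+/// Per-element sketch of the inner loop below (same variable names as the
+/// implementation):
+///
+/// ```ignore
+/// let j = pc * 8 + bit as usize;                     // column in the dequantized row
+/// let w4 = ((packed_val >> (bit * 4)) & 0xF) as f32; // nibbles are packed LSB-first
+/// let g = j / group_size;                            // this column's quantization group
+/// row[j] = f16::from_f32(w4 * sc[i * groups + g] + bi[i * groups + g]);
+/// ```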
+pub fn dequantize_packed_4bit_f16( + packed: &Tensor, + scales: &Tensor, + biases: &Tensor, + group_size: usize, +) -> candle_core::Result { + if packed.rank() == 3 { + let n = packed.dim(0)?; + let slices: Vec = (0..n) + .map(|i| { + let p = packed.get(i)?; + let s = scales.get(i)?; + let b = biases.get(i)?; + dequantize_packed_4bit_f16(&p, &s, &b, group_size) + }) + .collect::>()?; + return Tensor::stack(&slices, 0); + } + let (rows, packed_cols) = packed.dims2()?; + let cols = packed_cols * 8; + let (_, groups) = scales.dims2()?; + + let pw: Vec = packed.flatten_all()?.to_vec1::()?; + let sc: Vec = scales + .to_dtype(DType::F32)? + .flatten_all()? + .to_vec1::()?; + let bi: Vec = biases + .to_dtype(DType::F32)? + .flatten_all()? + .to_vec1::()?; + + use half::f16; + use rayon::prelude::*; + let mut weight = vec![f16::ZERO; rows * cols]; + weight + .par_chunks_mut(cols) + .enumerate() + .for_each(|(i, row)| { + for pc in 0..packed_cols { + let packed_val = pw[i * packed_cols + pc]; + for bit in 0..8u32 { + let j = pc * 8 + bit as usize; + let w4 = ((packed_val >> (bit * 4)) & 0xF) as f32; + let g = j / group_size; + let scale = sc[i * groups + g]; + let bias = bi[i * groups + g]; + row[j] = f16::from_f32(w4 * scale + bias); + } + } + }); + + Tensor::from_vec(weight, (rows, cols), &Device::Cpu) +} + /// Custom VarBuilder backend that transparently dequantizes GPTQ weights. /// /// When asked for `foo.weight`, checks if `foo.qweight` exists and, if so, @@ -209,12 +281,7 @@ struct GptqBackend { } impl GptqBackend { - fn load_tensor( - &self, - name: &str, - dtype: DType, - dev: &Device, - ) -> candle_core::Result { + fn load_tensor(&self, name: &str, dtype: DType, dev: &Device) -> candle_core::Result { // Strip the ".weight" suffix to get the parameter prefix. let prefix = name.strip_suffix(".weight").unwrap_or(name); let qweight_name = format!("{prefix}.qweight"); @@ -235,8 +302,15 @@ impl GptqBackend { let packed = self.inner.load(name, &Device::Cpu)?; let scales = self.inner.load(&scales_name, &Device::Cpu)?; let biases = self.inner.load(&biases_name, &Device::Cpu)?; - let weight = dequantize_packed_4bit(&packed, &scales, &biases, self.group_size)?; - weight.to_dtype(dtype)?.to_device(dev) + // Use F16-native dequant when target is F16 to halve peak memory + // (skips the 259 MiB F32 intermediate for large tensors). + let weight = if dtype == DType::F16 { + dequantize_packed_4bit_f16(&packed, &scales, &biases, self.group_size)? + } else { + let w = dequantize_packed_4bit(&packed, &scales, &biases, self.group_size)?; + w.to_dtype(dtype)? + }; + weight.to_device(dev) } else { // Non-quantized tensor — load directly. self.inner.load(name, dev)?.to_dtype(dtype) @@ -308,6 +382,183 @@ pub unsafe fn load_gptq_var_builder<'a>( Ok(VarBuilder::from_backend(backend, dtype, device.clone())) } +/// Create a VarBuilder for MLX-style packed 4-bit quantized models. +/// +/// Reuses the `GptqBackend` since MLX and GPTQ-affine share the same +/// dequantization path (packed uint32 + scales + biases). +/// +/// # Safety +/// +/// Inherits the mmap safety requirements from `MmapedSafetensors`. 
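+///
+/// Usage sketch (the module path and file list are assumptions; `group_size`
+/// normally comes from `mlx_quant::mlx_group_size` on the model's config.json):
+///
+/// ```ignore
+/// let group_size = crate::utils::mlx_quant::mlx_group_size(&config_path);
+/// let vb = unsafe { load_mlx_var_builder(&weight_files, DType::F16, &device, group_size)? };
+/// // Any "<prefix>.weight" with a matching "<prefix>.scales" is dequantized to F16 on load.
+/// ```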
+pub unsafe fn load_mlx_var_builder<'a>( + filenames: &[std::path::PathBuf], + dtype: DType, + device: &Device, + group_size: usize, +) -> anyhow::Result> { + let inner = MmapedSafetensors::multi(filenames)?; + + let scales_count = inner + .tensors() + .iter() + .filter(|(name, _)| name.ends_with(".scales")) + .count(); + log::info!( + "MLX 4-bit model: {} quantized tensors will be dequantized at load time (group_size={})", + scales_count, + group_size, + ); + + let backend: Box = Box::new(GptqBackend { inner, group_size }); + Ok(VarBuilder::from_backend(backend, dtype, device.clone())) +} + +/// VarBuilder backend for Metal-native MLX 4-bit quantized models. +/// +/// Unlike `GptqBackend`, this does NOT dequantize packed weights. Instead: +/// - `{prefix}.weight` → loaded as raw U32 packed tensor (no dequant) +/// - `{prefix}.scales` → loaded as F16 +/// - `{prefix}.biases` → loaded as F16 +/// - Non-quantized tensors → loaded normally with dtype conversion +/// +/// This keeps weights at 0.5 bytes/element on Metal, enabling the fused +/// `q4_matmul_f16` kernel to dequantize on-the-fly during matmul. +struct MetalMlxBackend { + inner: MmapedSafetensors, + group_size: usize, +} + +impl MetalMlxBackend { + /// Dequantize a quantized tensor (fallback for non-linear layers like embeddings). + fn load_dequantized( + &self, + name: &str, + dtype: DType, + dev: &Device, + ) -> candle_core::Result { + let prefix = name.strip_suffix(".weight").unwrap_or(name); + let scales_name = format!("{prefix}.scales"); + let biases_name = format!("{prefix}.biases"); + + if name.ends_with(".weight") && self.inner.get(&scales_name).is_ok() { + let packed = self.inner.load(name, &Device::Cpu)?; + let scales = self.inner.load(&scales_name, &Device::Cpu)?; + let biases = self.inner.load(&biases_name, &Device::Cpu)?; + let weight = if dtype == DType::F16 { + dequantize_packed_4bit_f16(&packed, &scales, &biases, self.group_size)? + } else { + let w = dequantize_packed_4bit(&packed, &scales, &biases, self.group_size)?; + w.to_dtype(dtype)? + }; + weight.to_device(dev) + } else { + self.inner.load(name, dev)?.to_dtype(dtype) + } + } + + /// Load a tensor in its raw format (packed U32 for quantized, native for others). + /// Used by get_unchecked() for the fused q4 kernel path. + fn load_tensor(&self, name: &str, dtype: DType, dev: &Device) -> candle_core::Result { + let prefix = name.strip_suffix(".weight").unwrap_or(name); + let scales_name = format!("{prefix}.scales"); + + // Check if this is a quantized tensor (has matching .scales) + if name.ends_with(".weight") && self.inner.get(&scales_name).is_ok() { + // Return packed U32 weight directly — no dequantization. + self.inner.load(name, dev) + } else if name.ends_with(".scales") || name.ends_with(".biases") { + // Scales and biases: load as F16 (their native format). + let t = self.inner.load(name, dev)?; + t.to_dtype(DType::F16) + } else { + // Non-quantized tensor — standard load with dtype conversion. + self.inner.load(name, dev)?.to_dtype(dtype) + } + } +} + +impl SimpleBackend for MetalMlxBackend { + fn get( + &self, + s: Shape, + name: &str, + _h: Init, + dtype: DType, + dev: &Device, + ) -> candle_core::Result { + // Shape-checked path: try raw load first (works for non-quantized tensors + // and quantized linear layers where caller expects packed shape). 
+impl SimpleBackend for MetalMlxBackend {
+    fn get(
+        &self,
+        s: Shape,
+        name: &str,
+        _h: Init,
+        dtype: DType,
+        dev: &Device,
+    ) -> candle_core::Result<Tensor> {
+        // Shape-checked path: try raw load first (works for non-quantized tensors
+        // and quantized linear layers where the caller expects the packed shape).
+        let tensor = self.load_tensor(name, dtype, dev)?;
+        if tensor.shape() == &s {
+            return Ok(tensor);
+        }
+        // Shape mismatch — this is a non-linear layer (embedding, norm) that has
+        // quantized weights but needs full F16. Fall back to dequantization.
+        let dequantized = self.load_dequantized(name, dtype, dev)?;
+        if dequantized.shape() != &s {
+            Err(candle_core::Error::UnexpectedShape {
+                msg: format!("shape mismatch for {name}"),
+                expected: s,
+                got: dequantized.shape().clone(),
+            }
+            .bt())?
+        }
+        Ok(dequantized)
+    }
+
+    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> candle_core::Result<Tensor> {
+        self.load_tensor(name, dtype, dev)
+    }
+
+    fn contains_tensor(&self, name: &str) -> bool {
+        if self.inner.get(name).is_ok() {
+            return true;
+        }
+        let prefix = name.strip_suffix(".weight").unwrap_or(name);
+        self.inner.get(&format!("{prefix}.scales")).is_ok()
+    }
+}
+
+/// Create a VarBuilder for Metal-native MLX 4-bit quantized models.
+///
+/// Unlike `load_mlx_var_builder`, this does NOT dequantize packed weights.
+/// Quantized tensors are returned raw (U32 packed, F16 scales/biases) so
+/// the fused `q4_matmul_f16` kernel can operate directly on packed data.
+///
+/// # Safety
+///
+/// Inherits the mmap safety requirements from `MmapedSafetensors`.
+pub unsafe fn load_metal_mlx_var_builder<'a>(
+    filenames: &[std::path::PathBuf],
+    dtype: DType,
+    device: &Device,
+    group_size: usize,
+) -> anyhow::Result<VarBuilder<'a>> {
+    let inner = MmapedSafetensors::multi(filenames)?;
+
+    let scales_count = inner
+        .tensors()
+        .iter()
+        .filter(|(name, _)| name.ends_with(".scales"))
+        .count();
+    log::info!(
+        "MLX 4-bit model on Metal: {} quantized tensors will use the fused q4 kernel (group_size={}, 4x memory reduction)",
+        scales_count,
+        group_size,
+    );
+
+    let backend: Box<dyn SimpleBackend> = Box::new(MetalMlxBackend { inner, group_size });
+    Ok(VarBuilder::from_backend(backend, dtype, device.clone()))
+}
+
+/// Returns the group_size stored in this backend (used to pass quantization
+/// info to model loading code).
+pub fn metal_mlx_group_size() -> Option<usize> {
+    // This is a compile-time marker; the actual group_size comes from config.json
+    // and is passed through MlxQuantization.
+    None
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/cake-core/src/utils/mlx_quant.rs b/cake-core/src/utils/mlx_quant.rs
new file mode 100644
index 0000000..9925b7f
--- /dev/null
+++ b/cake-core/src/utils/mlx_quant.rs
@@ -0,0 +1,226 @@
+//! MLX 4-bit quantization detection.
+//!
+//! MLX-community models store linear-layer weights as three tensors:
+//! - `*.weight` — uint32, shape `(rows, cols / 8)`: 8 x 4-bit values packed per uint32
+//! - `*.scales` — f16, shape `(rows, groups)`: one scale per group
+//! - `*.biases` — f16, shape `(rows, groups)`: one bias per group
+//!
+//! The config.json uses a `"quantization"` key (not `"quantization_config"`):
+//! `{"quantization": {"group_size": 64, "bits": 4}}`
+//!
+//! Dequantization formula: `w_dequant[i, j] = w4(i, j) * scale(i, group(j)) + bias(i, group(j))`
+//!
+//! This is identical to the affine 4-bit path in `gptq::dequantize_packed_4bit`,
+//! so we reuse the GPTQ backend for actual dequantization.
+
+use std::path::Path;
+
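+// Usage sketch (the model path is hypothetical; detection only reads config.json):
+//
+//     use std::path::Path;
+//     let cfg = Path::new("/models/Qwen2.5-7B-Instruct-4bit/config.json");
+//     if is_mlx_quantized(cfg) {
+//         let group_size = mlx_group_size(cfg); // 64 unless the config says otherwise
+//         let bits = mlx_bits(cfg);             // 4 for the supported models
+//     }
+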
+/// Check whether a model uses MLX-style packed quantization.
+///
+/// Detects the `"quantization"` key (MLX convention) with `bits: 4` and no
+/// `quant_method` field (which would indicate GPTQ/FP8 instead).
+/// Also catches `"quantization_config"` entries that lack `quant_method` and `mode`.
+pub fn is_mlx_quantized(config_path: &Path) -> bool {
+    let Ok(data) = std::fs::read_to_string(config_path) else {
+        return false;
+    };
+    let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) else {
+        return false;
+    };
+
+    for root in [&json, json.get("text_config").unwrap_or(&json)] {
+        // Primary MLX key: "quantization" (not "quantization_config")
+        if let Some(q) = root.get("quantization") {
+            if is_mlx_quant_block(q) {
+                return true;
+            }
+        }
+        // Some MLX models also populate "quantization_config" with the same data.
+        // Catch those that slipped past GPTQ detection (no quant_method, no mode).
+        if let Some(qc) = root.get("quantization_config") {
+            if is_mlx_quant_block(qc) {
+                return true;
+            }
+        }
+    }
+    false
+}
+
+/// Returns true if the JSON block looks like the implemented MLX quantization:
+/// it has `bits: 4`, has `group_size`, and does NOT have `quant_method`.
+fn is_mlx_quant_block(qc: &serde_json::Value) -> bool {
+    // Must not have quant_method — that's GPTQ or FP8.
+    if qc.get("quant_method").is_some() {
+        return false;
+    }
+    // Only detect bit widths we actually implement dequant for.
+    // dequantize_packed_4bit assumes 8 × 4-bit values per uint32.
+    let has_bits = qc
+        .get("bits")
+        .and_then(|b| b.as_u64())
+        .map(|b| b == 4)
+        .unwrap_or(false);
+    let has_group_size = qc.get("group_size").is_some();
+    has_bits && has_group_size
+}
+
+/// Read the group_size from the MLX quantization config (defaults to 64).
+pub fn mlx_group_size(config_path: &Path) -> usize {
+    let Ok(data) = std::fs::read_to_string(config_path) else {
+        return 64;
+    };
+    let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) else {
+        return 64;
+    };
+    for root in [&json, json.get("text_config").unwrap_or(&json)] {
+        for key in ["quantization", "quantization_config"] {
+            if let Some(gs) = root
+                .get(key)
+                .and_then(|q| q.get("group_size"))
+                .and_then(|v| v.as_u64())
+            {
+                return gs as usize;
+            }
+        }
+    }
+    64
+}
+
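+// Group bookkeeping sketch: a row of `in_features` weights carries
+// `in_features / group_size` (scale, bias) pairs. For example, with
+// in_features = 4096 and group_size = 64, each row has 64 groups, so the
+// scales and biases tensors are both (rows, 64) in F16.
+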
+/// Read the bits from the MLX quantization config (defaults to 4).
+pub fn mlx_bits(config_path: &Path) -> usize {
+    let Ok(data) = std::fs::read_to_string(config_path) else {
+        return 4;
+    };
+    let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) else {
+        return 4;
+    };
+    for root in [&json, json.get("text_config").unwrap_or(&json)] {
+        for key in ["quantization", "quantization_config"] {
+            if let Some(bits) = root
+                .get(key)
+                .and_then(|q| q.get("bits"))
+                .and_then(|v| v.as_u64())
+            {
+                return bits as usize;
+            }
+        }
+    }
+    4
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_detects_mlx_quantization_key() {
+        let dir = tempfile::tempdir().unwrap();
+        let cfg = dir.path().join("config.json");
+        std::fs::write(&cfg, r#"{"quantization": {"group_size": 64, "bits": 4}}"#).unwrap();
+        assert!(is_mlx_quantized(&cfg));
+    }
+
+    #[test]
+    fn test_detects_mlx_quantization_config_key() {
+        let dir = tempfile::tempdir().unwrap();
+        let cfg = dir.path().join("config.json");
+        std::fs::write(
+            &cfg,
+            r#"{"quantization_config": {"group_size": 64, "bits": 4}}"#,
+        )
+        .unwrap();
+        assert!(is_mlx_quantized(&cfg));
+    }
+
+    #[test]
+    fn test_rejects_gptq_quant_method() {
+        let dir = tempfile::tempdir().unwrap();
+        let cfg = dir.path().join("config.json");
+        std::fs::write(
+            &cfg,
+            r#"{"quantization_config": {"quant_method": "gptq", "bits": 4, "group_size": 128}}"#,
+        )
+        .unwrap();
+        assert!(!is_mlx_quantized(&cfg));
+    }
+
+    #[test]
+    fn test_rejects_fp8_quant_method() {
+        let dir = tempfile::tempdir().unwrap();
+        let cfg = dir.path().join("config.json");
+        std::fs::write(
+            &cfg,
+            r#"{"quantization_config": {"quant_method": "fp8", "bits": 8, "group_size": 128}}"#,
+        )
+        .unwrap();
+        assert!(!is_mlx_quantized(&cfg));
+    }
+
+    #[test]
+    fn test_rejects_no_quantization() {
+        let dir = tempfile::tempdir().unwrap();
+        let cfg = dir.path().join("config.json");
+        std::fs::write(&cfg, r#"{"hidden_size": 4096}"#).unwrap();
+        assert!(!is_mlx_quantized(&cfg));
+    }
+
+    #[test]
+    fn test_rejects_3bit_mlx() {
+        // 3-bit packing is not implemented — dequant assumes 8 × 4-bit per uint32
+        let dir = tempfile::tempdir().unwrap();
+        let cfg = dir.path().join("config.json");
+        std::fs::write(&cfg, r#"{"quantization": {"group_size": 64, "bits": 3}}"#).unwrap();
+        assert!(!is_mlx_quantized(&cfg));
+    }
+
+    #[test]
+    fn test_group_size_from_quantization_key() {
+        let dir = tempfile::tempdir().unwrap();
+        let cfg = dir.path().join("config.json");
+        std::fs::write(&cfg, r#"{"quantization": {"group_size": 32, "bits": 4}}"#).unwrap();
+        assert_eq!(mlx_group_size(&cfg), 32);
+    }
+
+    #[test]
+    fn test_group_size_default() {
+        let dir = tempfile::tempdir().unwrap();
+        let cfg = dir.path().join("config.json");
+        std::fs::write(&cfg, r#"{"hidden_size": 4096}"#).unwrap();
+        assert_eq!(mlx_group_size(&cfg), 64);
+    }
+
+    #[test]
+    fn test_bits_from_config() {
+        let dir = tempfile::tempdir().unwrap();
+        let cfg = dir.path().join("config.json");
+        std::fs::write(&cfg, r#"{"quantization": {"group_size": 64, "bits": 3}}"#).unwrap();
+        assert_eq!(mlx_bits(&cfg), 3);
+    }
+
+    #[test]
+    fn test_nested_text_config() {
+        let dir = tempfile::tempdir().unwrap();
+        let cfg = dir.path().join("config.json");
+        std::fs::write(
+            &cfg,
+            r#"{"text_config": {"quantization": {"group_size": 64, "bits": 4}}}"#,
+        )
+        .unwrap();
+        assert!(is_mlx_quantized(&cfg));
+        assert_eq!(mlx_group_size(&cfg), 64);
+    }
+
+    #[test]
+    fn test_with_affine_mode_still_detected() {
+        // MLX models sometimes have mode: "affine" — these should still be detected
+        // since they don't have quant_method.
+        let dir = tempfile::tempdir().unwrap();
+        let cfg = dir.path().join("config.json");
+        std::fs::write(
+            &cfg,
+            r#"{"quantization": {"group_size": 64, "bits": 4, "mode": "affine"}}"#,
+        )
+        .unwrap();
+        assert!(is_mlx_quantized(&cfg));
+    }
+}
diff --git a/cake-core/src/utils/mod.rs b/cake-core/src/utils/mod.rs
index 2a827d9..04b6c83 100644
--- a/cake-core/src/utils/mod.rs
+++ b/cake-core/src/utils/mod.rs
@@ -1,11 +1,13 @@
 //! Utility functions and abstractions.
 
-pub mod fp8;
 #[cfg(feature = "flash-attn")]
 pub mod flash_attn;
+pub mod fp8;
 pub mod gguf;
 pub mod gptq;
 pub mod hf;
+pub mod mlx_quant;
 pub mod models;
 pub mod native_dtype_backend;
+pub mod quantized_linear;
 pub mod split;
@@ -49,6 +51,12 @@ pub trait Quantization: Send + Sync {
         None
     }
 
+    /// Returns the MLX quantization group size if this is MLX 4-bit, None otherwise.
+    /// Used by model loading code to detect the fused q4 Metal path.
+    fn mlx_group_size(&self) -> Option<usize> {
+        None
+    }
+
     /// Estimate in-memory layer size given on-disk size and target dtype bytes.
     /// Default: no expansion (on-disk size = in-memory size).
     fn estimate_layer_vram(&self, on_disk_bytes: u64, _dtype_bytes: u64) -> u64 {
@@ -124,6 +132,60 @@ impl Quantization for GptqQuantization {
     }
 }
 
+/// MLX packed 4-bit quantization — wraps `gptq::load_mlx_var_builder`.
+///
+/// MLX-community models pack 8 x 4-bit values per uint32 along the output
+/// dimension, with per-group scales and biases. The dequantization math is
+/// identical to GPTQ-affine, so we reuse the same backend.
+pub struct MlxQuantization {
+    pub group_size: usize,
+    pub bits: usize,
+}
+
+impl Quantization for MlxQuantization {
+    fn name(&self) -> &str {
+        "mlx"
+    }
+
+    fn mlx_group_size(&self) -> Option<usize> {
+        Some(self.group_size)
+    }
+
+    unsafe fn load_var_builder<'a>(
+        &self,
+        filenames: &[PathBuf],
+        dtype: DType,
+        device: &Device,
+    ) -> Result<VarBuilder<'a>> {
+        // On Metal: use the fused q4 backend that keeps weights packed (4x memory savings).
+        if matches!(device, Device::Metal(_)) {
+            log::info!("MLX 4-bit on Metal: using the fused q4 kernel path (no dequantization)");
+            return gptq::load_metal_mlx_var_builder(filenames, dtype, device, self.group_size)
+                .map_err(|e| anyhow!("can't create metal mlx varbuilder: {e:?}"));
+        }
+        // Non-Metal: dequantize at load time (existing path).
+        gptq::load_mlx_var_builder(filenames, dtype, device, self.group_size)
+            .map_err(|e| anyhow!("can't create mlx varbuilder: {e:?}"))
+    }
+
+    fn estimate_layer_vram(&self, on_disk_bytes: u64, _dtype_bytes: u64) -> u64 {
+        // On Metal with the fused kernel, weights stay packed at ~0.5 bytes/element;
+        // only the scales and biases expand, and they are small (1/group_size of
+        // the total), so we add a conservative 25% over the on-disk size.
+        #[cfg(feature = "metal")]
+        {
+            // Packed weights + scales/biases overhead ≈ 0.6 bytes/element —
+            // far less than full dequantization (dtype_bytes per element).
+            on_disk_bytes + on_disk_bytes / 4
+        }
+        #[cfg(not(feature = "metal"))]
+        {
+            // Non-Metal: full dequant expansion. Packed 4-bit is ~0.5 bytes/element,
+            // so expanding to dtype_bytes per element multiplies the size by 2 * dtype_bytes.
+            on_disk_bytes * _dtype_bytes * 2
+        }
+    }
+}
+
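+// Worked example of the two estimates above (illustrative numbers): a layer
+// that is 1000 bytes on disk stays at 1000 + 1000/4 = 1250 bytes under the
+// Metal fused path, but expands to 1000 * 2 * 2 = 4000 bytes when dequantized
+// to F16 and 1000 * 4 * 2 = 8000 bytes for F32 — the unit tests below pin
+// both branches.
+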
 /// Detect quantization strategy from a model's config.json.
 pub fn detect_quantization(config_path: &Path) -> Box<dyn Quantization> {
     if fp8::is_fp8_quantized(config_path) {
@@ -133,6 +195,14 @@ pub fn detect_quantization(config_path: &Path) -> Box<dyn Quantization> {
         let gs = gptq::gptq_group_size(config_path);
         log::info!("model uses GPTQ quantization (group_size={gs}) — weights will be dequantized at load time");
         Box::new(GptqQuantization { group_size: gs })
+    } else if mlx_quant::is_mlx_quantized(config_path) {
+        let gs = mlx_quant::mlx_group_size(config_path);
+        let bits = mlx_quant::mlx_bits(config_path);
+        log::info!("model uses MLX {bits}-bit quantization (group_size={gs})");
+        Box::new(MlxQuantization {
+            group_size: gs,
+            bits,
+        })
     } else {
         Box::new(NoQuantization)
     }
@@ -199,10 +269,11 @@ pub fn load_safetensors_paths_from_index(
             safetensors_files.insert(file.to_string());
         }
     }
-    let safetensors_files = safetensors_files
+    let mut safetensors_files: Vec<_> = safetensors_files
         .iter()
         .map(|v| parent_dir.join(v))
-        .collect::<Vec<_>>();
+        .collect();
+    safetensors_files.sort();
 
     Ok(safetensors_files)
 }
@@ -308,10 +379,8 @@ pub fn load_var_builder_for_local_layers<'a>(
         }
     }
 
-    let filenames: Vec<PathBuf> = needed_shards
-        .iter()
-        .map(|f| parent_dir.join(f))
-        .collect();
+    let mut filenames: Vec<PathBuf> = needed_shards.iter().map(|f| parent_dir.join(f)).collect();
+    filenames.sort();
 
     log::info!(
         "loading {} of {} shard file(s) for local layers",
@@ -369,7 +438,8 @@ pub fn load_var_builder_for_specific_layers<'a>(
         .collect::<HashSet<_>>()
        .len();
 
-    let filenames: Vec<PathBuf> = needed_shards.iter().map(|f| parent_dir.join(f)).collect();
+    let mut filenames: Vec<PathBuf> = needed_shards.iter().map(|f| parent_dir.join(f)).collect();
+    filenames.sort();
 
     log::info!(
        "loading {} of {} shard file(s) for {} layers",
@@ -511,6 +581,65 @@ mod tests {
         assert_eq!(q.name(), "gptq");
     }
 
+    #[test]
+    fn detect_quantization_mlx_quantization_key() {
+        let tmp = tempfile::tempdir().unwrap();
+        let cfg_path = tmp.path().join("config.json");
+        fs::write(
+            &cfg_path,
+            r#"{"quantization": {"group_size": 64, "bits": 4}}"#,
+        )
+        .unwrap();
+        let q = detect_quantization(&cfg_path);
+        assert_eq!(q.name(), "mlx");
+    }
+
+    #[test]
+    fn detect_quantization_mlx_quantization_config_key() {
+        let tmp = tempfile::tempdir().unwrap();
+        let cfg_path = tmp.path().join("config.json");
+        fs::write(
+            &cfg_path,
+            r#"{"quantization_config": {"group_size": 64, "bits": 4}}"#,
+        )
+        .unwrap();
+        let q = detect_quantization(&cfg_path);
+        assert_eq!(q.name(), "mlx");
+    }
+
+    #[test]
+    fn detect_quantization_mlx_not_confused_with_gptq() {
+        let tmp = tempfile::tempdir().unwrap();
+        let cfg_path = tmp.path().join("config.json");
+        fs::write(
+            &cfg_path,
+            r#"{"quantization_config": {"quant_method": "gptq", "bits": 4, "group_size": 128}}"#,
+        )
+        .unwrap();
+        let q = detect_quantization(&cfg_path);
+        assert_eq!(q.name(), "gptq");
+    }
+
+    #[test]
+    fn mlx_quantization_estimate_layer_vram_expansion() {
+        let q = MlxQuantization {
+            group_size: 64,
+            bits: 4,
+        };
+        #[cfg(feature = "metal")]
+        {
+            // Metal fused kernel: weights stay packed, ~1.25x the on-disk size.
+            assert_eq!(q.estimate_layer_vram(1000, 2), 1250);
+            assert_eq!(q.estimate_layer_vram(1000, 4), 1250);
+        }
+        #[cfg(not(feature = "metal"))]
+        {
+            // Non-Metal: full dequant expansion (4x for F16, 8x for F32).
+            assert_eq!(q.estimate_layer_vram(1000, 2), 4000);
+            assert_eq!(q.estimate_layer_vram(1000, 4), 8000);
+        }
+    }
+
     #[test]
     fn no_quantization_estimate_layer_vram_passthrough() {
         let q = NoQuantization;
diff --git a/cake-core/src/utils/quantized_linear.rs b/cake-core/src/utils/quantized_linear.rs
new file mode 100644
index 0000000..ebee228
--- /dev/null
+++ b/cake-core/src/utils/quantized_linear.rs
@@ -0,0 +1,158 @@
+//! Quantized linear layer support for fused 4-bit Metal matmul.
+//!
+//! [`QuantizedWeight`] stores packed 4-bit weights (U32), per-group F16 scales,
+//! and per-group F16 biases on a Metal device without dequantizing.
+//! [`LinearWeight`] is an enum dispatching between standard dense (`Tensor`)
+//! weights and quantized weights, used by MLP and Attention layers.
+
+use candle_core::{Result, Tensor};
+
+use crate::backends::ComputeBackend;
+
+/// Packed 4-bit weight data kept on-device (Metal) without dequantization.
+///
+/// Memory layout matches the Phase 1 `q4_matmul_f16` kernel:
+/// - `packed`: (out_features, in_features/8) U32 — 8 nibbles per U32, LSB-first
+/// - `scales`: (out_features, num_groups) F16
+/// - `biases`: (out_features, num_groups) F16
+#[derive(Debug, Clone)]
+pub struct QuantizedWeight {
+    /// Packed 4-bit weight tensor, U32, on the Metal device.
+    pub packed: Tensor,
+    /// Per-group scale factors, F16, on the Metal device.
+    pub scales: Tensor,
+    /// Per-group bias offsets, F16, on the Metal device.
+    pub biases: Tensor,
+    /// Number of elements per quantization group.
+    pub group_size: usize,
+}
+
+/// A linear layer weight that is either dense (a standard F16 tensor) or
+/// quantized (packed 4-bit with scales/biases). Used as a drop-in replacement
+/// for raw `Tensor` weight fields in MLP and Attention.
+#[derive(Debug, Clone)]
+pub enum LinearWeight {
+    /// Standard dense weight tensor (pre-transposed for the Metal backend).
+    Dense(Tensor),
+    /// Quantized 4-bit weight kept packed on Metal.
+    Quantized(QuantizedWeight),
+}
+
+impl LinearWeight {
+    /// Perform `x @ weight^T + bias`, dispatching to the fused q4 kernel
+    /// for quantized weights on Metal.
+    pub fn forward(
+        &self,
+        x: &Tensor,
+        bias: Option<&Tensor>,
+        backend: &dyn ComputeBackend,
+    ) -> Result<Tensor> {
+        match self {
+            LinearWeight::Dense(weight) => backend.linear_forward(x, weight, bias),
+            LinearWeight::Quantized(qw) => {
+                let out = backend.q4_linear_forward(
+                    &qw.packed,
+                    &qw.scales,
+                    &qw.biases,
+                    x,
+                    qw.group_size,
+                )?;
+                match bias {
+                    Some(b) => out.broadcast_add(b),
+                    None => Ok(out),
+                }
+            }
+        }
+    }
+
+    /// Wrap a dense weight tensor.
+    pub fn dense(weight: Tensor) -> Self {
+        LinearWeight::Dense(weight)
+    }
+
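+    // Usage sketch (tensors assumed already on the target device): both
+    // variants are driven through the same call, so call sites never branch
+    // on the quantization scheme.
+    //
+    //     let lw = LinearWeight::quantized(packed, scales, biases, 64);
+    //     let y = lw.forward(&x, None, backend)?; // fused q4 kernel on Metal
+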
+    /// Wrap quantized weight components.
+    pub fn quantized(packed: Tensor, scales: Tensor, biases: Tensor, group_size: usize) -> Self {
+        LinearWeight::Quantized(QuantizedWeight {
+            packed,
+            scales,
+            biases,
+            group_size,
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use candle_core::{DType, Device};
+
+    #[test]
+    fn test_linear_weight_dense_forward() {
+        let backend = crate::backends::create_backend(&Device::Cpu);
+        // Simple 3x2 weight, input 1x2
+        let w = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], (3, 2), &Device::Cpu)
+            .unwrap();
+        let w = backend.preprocess_linear_weight(&w).unwrap();
+        let lw = LinearWeight::Dense(w);
+        let x = Tensor::from_vec(vec![1.0f32, 1.0], (1, 2), &Device::Cpu).unwrap();
+        let out = lw.forward(&x, None, &*backend).unwrap();
+        let vals: Vec<f32> = out.flatten_all().unwrap().to_vec1().unwrap();
+        // weight rows: [1,2], [3,4], [5,6] -> x@w^T = [3, 7, 11]
+        assert!((vals[0] - 3.0).abs() < 1e-5);
+        assert!((vals[1] - 7.0).abs() < 1e-5);
+        assert!((vals[2] - 11.0).abs() < 1e-5);
+    }
+
+    #[test]
+    fn test_linear_weight_dense_with_bias() {
+        let backend = crate::backends::create_backend(&Device::Cpu);
+        let w = Tensor::from_vec(vec![1.0f32, 0.0, 0.0, 1.0], (2, 2), &Device::Cpu).unwrap();
+        let w = backend.preprocess_linear_weight(&w).unwrap();
+        let lw = LinearWeight::Dense(w);
+        let x = Tensor::from_vec(vec![3.0f32, 5.0], (1, 2), &Device::Cpu).unwrap();
+        let bias = Tensor::from_vec(vec![10.0f32, 20.0], 2, &Device::Cpu).unwrap();
+        let out = lw.forward(&x, Some(&bias), &*backend).unwrap();
+        let vals: Vec<f32> = out.flatten_all().unwrap().to_vec1().unwrap();
+        // identity weight + bias: [3+10, 5+20] = [13, 25]
+        assert!((vals[0] - 13.0).abs() < 1e-5);
+        assert!((vals[1] - 25.0).abs() < 1e-5);
+    }
+
+    #[test]
+    fn test_quantized_weight_cpu_fallback() {
+        let backend = crate::backends::create_backend(&Device::Cpu);
+        // in_features=8, out_features=2, group_size=8, 1 group
+        // All nibbles=1, scale=1.0, bias=0.0 -> all weights = 1.0
+        let packed =
+            Tensor::from_vec(vec![0x11111111u32; 2], (2, 1), &Device::Cpu).unwrap();
+        let scales = Tensor::from_vec(vec![1.0f32; 2], (2, 1), &Device::Cpu)
+            .unwrap()
+            .to_dtype(DType::F16)
+            .unwrap();
+        let biases = Tensor::from_vec(vec![0.0f32; 2], (2, 1), &Device::Cpu)
+            .unwrap()
+            .to_dtype(DType::F16)
+            .unwrap();
+        let qw = LinearWeight::quantized(packed, scales, biases, 8);
+
+        let x = Tensor::from_vec(vec![1.0f32; 8], (1, 8), &Device::Cpu)
+            .unwrap()
+            .to_dtype(DType::F16)
+            .unwrap();
+        let out = qw.forward(&x, None, &*backend).unwrap();
+        let vals: Vec<f32> = out
+            .to_dtype(DType::F32)
+            .unwrap()
+            .flatten_all()
+            .unwrap()
+            .to_vec1()
+            .unwrap();
+        // Each weight=1.0, x=[1;8] -> dot = 8.0
+        for (i, &v) in vals.iter().enumerate() {
+            assert!(
+                (v - 8.0).abs() < 0.5,
+                "output[{i}] = {v}, expected ~8.0"
+            );
+        }
+    }
+}
diff --git a/cake-core/tests/unit_tests/test_quantization.rs b/cake-core/tests/unit_tests/test_quantization.rs
index 2ae40ac..6f6986f 100644
--- a/cake-core/tests/unit_tests/test_quantization.rs
+++ b/cake-core/tests/unit_tests/test_quantization.rs
@@ -1,4 +1,4 @@
-//! Tests for FP8 dequantization paths.
+//! Tests for FP8 dequantization and fused 4-bit matmul paths.
 
 use candle_core::{DType, Device, Tensor};
 
@@ -16,3 +16,618 @@ fn test_fp8_to_f32() {
     }
     // Skip if F8 not supported on this platform
 }
+
+// ─── Fused 4-bit matmul tests ──────────────────────────────────────────
+
+/// CPU reference: dequantize packed 4-bit weights and multiply with x.
+/// Returns an (M, out_features) F32 result.
+#[allow(clippy::too_many_arguments)]
+fn cpu_q4_matmul_reference(
+    packed: &[u32],
+    scales: &[f32],
+    biases: &[f32],
+    x: &[f32],
+    m: usize,
+    in_features: usize,
+    out_features: usize,
+    group_size: usize,
+    num_groups: usize,
+) -> Vec<f32> {
+    let packed_cols = in_features / 8;
+    let mut output = vec![0f32; m * out_features];
+    for row in 0..m {
+        for col in 0..out_features {
+            let mut acc = 0f32;
+            for pc in 0..packed_cols {
+                let packed_val = packed[col * packed_cols + pc];
+                for bit in 0..8u32 {
+                    let j = pc * 8 + bit as usize;
+                    let w4 = ((packed_val >> (bit * 4)) & 0xF) as f32;
+                    let g = j / group_size;
+                    let scale = scales[col * num_groups + g];
+                    let bias = biases[col * num_groups + g];
+                    let w = w4 * scale + bias;
+                    acc += w * x[row * in_features + j];
+                }
+            }
+            output[row * out_features + col] = acc;
+        }
+    }
+    output
+}
+
+#[test]
+fn test_q4_matmul_cpu_reference_known_values() {
+    // Tiny case: 1x16 activation, 2x16 weight (packed as 2x2 u32), group_size=8
+    let in_features = 16;
+    let out_features = 2;
+    let group_size = 8;
+    let num_groups = in_features / group_size; // 2
+
+    // All nibbles = 1, scale=1.0, bias=0.0 → each weight = 1.0
+    // x = [1.0; 16] → dot product = 16.0 for each output
+    let packed = vec![0x11111111u32; out_features * (in_features / 8)]; // 2 * 2 = 4
+    let scales = vec![1.0f32; out_features * num_groups]; // 2 * 2 = 4
+    let biases = vec![0.0f32; out_features * num_groups]; // 2 * 2 = 4
+    let x = vec![1.0f32; in_features];
+
+    let result = cpu_q4_matmul_reference(
+        &packed,
+        &scales,
+        &biases,
+        &x,
+        1,
+        in_features,
+        out_features,
+        group_size,
+        num_groups,
+    );
+    assert_eq!(result.len(), 2);
+    assert!(
+        (result[0] - 16.0).abs() < 1e-5,
+        "expected 16.0, got {}",
+        result[0]
+    );
+    assert!(
+        (result[1] - 16.0).abs() < 1e-5,
+        "expected 16.0, got {}",
+        result[1]
+    );
+}
+
+#[test]
+fn test_q4_matmul_cpu_reference_varied_nibbles() {
+    // 1x8 activation, 1x8 weight, group_size=8 (1 group)
+    // packed[0] = 0x76543210 → nibbles [0,1,2,3,4,5,6,7]
+    // scale=0.5, bias=-1.0 → weights = [0*0.5-1, 1*0.5-1, ..., 7*0.5-1]
+    //                                = [-1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 2.5]
+    // x = [1,1,1,1,1,1,1,1] → dot = sum(weights) = -1 + (-0.5) + 0 + 0.5 + 1 + 1.5 + 2 + 2.5 = 6.0
+    let packed = vec![0x76543210u32];
+    let scales = vec![0.5f32];
+    let biases = vec![-1.0f32];
+    let x = vec![1.0f32; 8];
+
+    let result = cpu_q4_matmul_reference(&packed, &scales, &biases, &x, 1, 8, 1, 8, 1);
+    assert!((result[0] - 6.0).abs() < 1e-5, "expected 6.0, got {}", result[0]);
+}
+
+#[cfg(feature = "metal")]
+#[test]
+fn test_q4_matmul_f16_metal_vs_cpu() {
+    use candle_core::utils::metal_is_available;
+    use half::f16;
+
+    if !metal_is_available() {
+        return; // Skip on non-Metal platforms
+    }
+    let metal_device = match Device::new_metal(0) {
+        Ok(d) => d,
+        Err(_) => return,
+    };
+
+    // Dimensions: M=2, in_features=64, out_features=4, group_size=32
+    let m = 2usize;
+    let in_features = 64usize;
+    let out_features = 4usize;
+    let group_size = 32usize;
+    let num_groups = in_features / group_size; // 2
+    let packed_cols = in_features / 8; // 8
+
+    // Generate deterministic test data
+    let mut packed_data = vec![0u32; out_features * packed_cols];
+    for (i, val) in packed_data.iter_mut().enumerate() {
+        // Varying nibble patterns
+        let base = (i * 7 + 3) as u32;
+        *val = 0;
+        for bit in 0..8u32 {
+            let nibble = (base + bit * 5) % 16;
+            *val |= nibble << (bit * 4);
+        }
+    }
+
+    let mut scales_f32 = vec![0f32; out_features * num_groups];
+    for (i, s) in scales_f32.iter_mut().enumerate() {
+        *s = 0.1 + (i as f32) * 0.05;
+    }
+
+    let mut biases_f32 = vec![0f32; out_features * num_groups];
+    for (i, b) in biases_f32.iter_mut().enumerate() {
+        *b = -0.5 + (i as f32) * 0.1;
+    }
+
+    let mut x_f32 = vec![0f32; m * in_features];
+    for (i, v) in x_f32.iter_mut().enumerate() {
+        *v = ((i % 17) as f32 - 8.0) / 16.0;
+    }
+
+    // CPU reference (F32 precision)
+    let cpu_result = cpu_q4_matmul_reference(
+        &packed_data,
+        &scales_f32,
+        &biases_f32,
+        &x_f32,
+        m,
+        in_features,
+        out_features,
+        group_size,
+        num_groups,
+    );
+
+    // Metal path: create tensors on the Metal device
+    let packed_tensor =
+        Tensor::from_vec(packed_data, (out_features, packed_cols), &metal_device).unwrap();
+    let scales_f16: Vec<f16> = scales_f32.iter().map(|&v| f16::from_f32(v)).collect();
+    let biases_f16: Vec<f16> = biases_f32.iter().map(|&v| f16::from_f32(v)).collect();
+    let x_f16: Vec<f16> = x_f32.iter().map(|&v| f16::from_f32(v)).collect();
+
+    let scales_tensor =
+        Tensor::from_vec(scales_f16, (out_features, num_groups), &metal_device).unwrap();
+    let biases_tensor =
+        Tensor::from_vec(biases_f16, (out_features, num_groups), &metal_device).unwrap();
+    let x_tensor = Tensor::from_vec(x_f16, (m, in_features), &metal_device).unwrap();
+
+    let metal_result = cake_core::backends::q4_matmul_f16(
+        &packed_tensor,
+        &scales_tensor,
+        &biases_tensor,
+        &x_tensor,
+        group_size,
+    )
+    .unwrap();
+
+    // Verify shape
+    assert_eq!(metal_result.dims(), &[m, out_features]);
+    assert_eq!(metal_result.dtype(), DType::F16);
+
+    // Compare values with F16 tolerance
+    let metal_f32: Vec<f32> = metal_result
+        .to_device(&Device::Cpu)
+        .unwrap()
+        .to_dtype(DType::F32)
+        .unwrap()
+        .flatten_all()
+        .unwrap()
+        .to_vec1()
+        .unwrap();
+
+    let mut max_diff = 0f32;
+    let mut max_idx = 0usize;
+    for (i, (cpu, metal)) in cpu_result.iter().zip(metal_f32.iter()).enumerate() {
+        let diff = (cpu - metal).abs();
+        if diff > max_diff {
+            max_diff = diff;
+            max_idx = i;
+        }
+    }
+
+    // F16 accumulation tolerance: with 64 multiply-adds and F16 quantization of
+    // scales/biases/activations, expect up to ~0.5 absolute error.
+    let tolerance = 0.5f32;
+    assert!(
+        max_diff <= tolerance,
+        "q4_matmul_f16: max diff {max_diff} at index {max_idx} (cpu={} metal={}) exceeds tolerance {tolerance}",
+        cpu_result[max_idx],
+        metal_f32[max_idx]
+    );
+}
+
+#[cfg(feature = "metal")]
+#[test]
+fn test_q4_matmul_f16_metal_identity_weights() {
+    // Test with weights that are effectively identity-like:
+    // All nibbles=0, scale=0, bias=1 → all weights=1.0, so output = sum(x) per row.
+    use candle_core::utils::metal_is_available;
+    use half::f16;
+
+    if !metal_is_available() {
+        return;
+    }
+    let metal_device = match Device::new_metal(0) {
+        Ok(d) => d,
+        Err(_) => return,
+    };
+
+    let m = 1usize;
+    let in_features = 32usize;
+    let out_features = 2usize;
+    let group_size = 32usize;
+    let num_groups = 1usize;
+    let packed_cols = in_features / 8; // 4
+
+    // All nibbles=0, scale=0, bias=1.0 → all dequantized weights = 1.0
+    let packed_data = vec![0u32; out_features * packed_cols];
+    let scales_f16 = vec![f16::from_f32(0.0); out_features * num_groups];
+    let biases_f16 = vec![f16::from_f32(1.0); out_features * num_groups];
+
+    // x = [0.5; 32] → dot = 0.5 * 32 = 16.0
+    let x_f16 = vec![f16::from_f32(0.5); m * in_features];
+
+    let packed_tensor =
+        Tensor::from_vec(packed_data, (out_features, packed_cols), &metal_device).unwrap();
+    let scales_tensor =
+        Tensor::from_vec(scales_f16, (out_features, num_groups), &metal_device).unwrap();
+    let biases_tensor =
+        Tensor::from_vec(biases_f16, (out_features, num_groups), &metal_device).unwrap();
+    let x_tensor = Tensor::from_vec(x_f16, (m, in_features), &metal_device).unwrap();
+
+    let result = cake_core::backends::q4_matmul_f16(
+        &packed_tensor,
+        &scales_tensor,
+        &biases_tensor,
+        &x_tensor,
+        group_size,
+    )
+    .unwrap();
+
+    let result_f32: Vec<f32> = result
+        .to_device(&Device::Cpu)
+        .unwrap()
+        .to_dtype(DType::F32)
+        .unwrap()
+        .flatten_all()
+        .unwrap()
+        .to_vec1()
+        .unwrap();
+
+    assert_eq!(result.dims(), &[1, 2]);
+    for (i, &v) in result_f32.iter().enumerate() {
+        assert!(
+            (v - 16.0).abs() < 0.1,
+            "output[{i}] = {v}, expected 16.0"
+        );
+    }
+}
+
+#[cfg(feature = "metal")]
+#[test]
+fn test_q4_matmul_f16_metal_batch() {
+    // Test with M > 1 to verify batched dispatch works
+    use candle_core::utils::metal_is_available;
+    use half::f16;
+
+    if !metal_is_available() {
+        return;
+    }
+    let metal_device = match Device::new_metal(0) {
+        Ok(d) => d,
+        Err(_) => return,
+    };
+
+    let m = 4usize;
+    let in_features = 16usize;
+    let out_features = 3usize;
+    let group_size = 8usize;
+    let num_groups = 2usize;
+    let packed_cols = 2usize;
+
+    // Simple: all nibbles=2, scale=1.0, bias=0.0 → weights=2.0
+    // x row i = [i+1; 16] → dot = (i+1) * 2.0 * 16 = 32*(i+1)
+    let packed_data = vec![0x22222222u32; out_features * packed_cols];
+    let scales_f16 = vec![f16::from_f32(1.0); out_features * num_groups];
+    let biases_f16 = vec![f16::from_f32(0.0); out_features * num_groups];
+
+    let mut x_f16 = Vec::with_capacity(m * in_features);
+    for row in 0..m {
+        for _ in 0..in_features {
+            x_f16.push(f16::from_f32((row + 1) as f32));
+        }
+    }
+
+    let packed_tensor =
+        Tensor::from_vec(packed_data, (out_features, packed_cols), &metal_device).unwrap();
+    let scales_tensor =
+        Tensor::from_vec(scales_f16, (out_features, num_groups), &metal_device).unwrap();
+    let biases_tensor =
+        Tensor::from_vec(biases_f16, (out_features, num_groups), &metal_device).unwrap();
+    let x_tensor = Tensor::from_vec(x_f16, (m, in_features), &metal_device).unwrap();
+
+    let result = cake_core::backends::q4_matmul_f16(
+        &packed_tensor,
+        &scales_tensor,
+        &biases_tensor,
+        &x_tensor,
+        group_size,
+    )
+    .unwrap();
+
+    assert_eq!(result.dims(), &[4, 3]);
+
+    let result_f32: Vec<f32> = result
+        .to_device(&Device::Cpu)
+        .unwrap()
+        .to_dtype(DType::F32)
+        .unwrap()
+        .flatten_all()
+        .unwrap()
+        .to_vec1()
+        .unwrap();
+
+    for row in 0..m {
+        let expected = 32.0 * (row + 1) as f32;
+        for col in 0..out_features {
+            let actual = result_f32[row * out_features + col];
+            assert!(
+                (actual - expected).abs() < 1.0,
+                "row={row} col={col}: expected {expected}, got {actual}"
+            );
+        }
+    }
+}
+
+// ─── LinearWeight + QuantizedLinear tests ──────────────────────────
+
+#[test]
+fn test_linear_weight_dense_matches_backend() {
+    // Verify LinearWeight::Dense produces the same result as a direct backend.linear_forward
+    use cake_core::utils::quantized_linear::LinearWeight;
+
+    let backend = cake_core::backends::create_backend(&Device::Cpu);
+    let w = Tensor::from_vec(
+        vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
+        (3, 2),
+        &Device::Cpu,
+    )
+    .unwrap();
+    let w_preprocessed = backend.preprocess_linear_weight(&w).unwrap();
+    let lw = LinearWeight::Dense(w_preprocessed.clone());
+    let x = Tensor::from_vec(vec![1.0f32, 2.0], (1, 2), &Device::Cpu).unwrap();
+
+    let direct = backend.linear_forward(&x, &w_preprocessed, None).unwrap();
+    let via_lw = lw.forward(&x, None, &*backend).unwrap();
+
+    let direct_vals: Vec<f32> = direct.flatten_all().unwrap().to_vec1().unwrap();
+    let lw_vals: Vec<f32> = via_lw.flatten_all().unwrap().to_vec1().unwrap();
+    assert_eq!(direct_vals.len(), lw_vals.len());
+    for (d, l) in direct_vals.iter().zip(lw_vals.iter()) {
+        assert!((d - l).abs() < 1e-6, "mismatch: direct={d}, lw={l}");
+    }
+}
+
+#[test]
+fn test_linear_weight_quantized_cpu_fallback_known_values() {
+    // Quantized path with known values: all nibbles=2, scale=0.5, bias=-1.0
+    // weight[j] = 2 * 0.5 + (-1.0) = 0.0 for all elements.
+    // So output should be 0 for any input.
+    use cake_core::utils::quantized_linear::LinearWeight;
+
+    let backend = cake_core::backends::create_backend(&Device::Cpu);
+    let in_features = 8;
+    let out_features = 2;
+
+    // All nibbles=2
+    let packed = Tensor::from_vec(
+        vec![0x22222222u32; out_features],
+        (out_features, 1),
+        &Device::Cpu,
+    )
+    .unwrap();
+    let scales = Tensor::from_vec(vec![0.5f32; out_features], (out_features, 1), &Device::Cpu)
+        .unwrap()
+        .to_dtype(DType::F16)
+        .unwrap();
+    let biases = Tensor::from_vec(
+        vec![-1.0f32; out_features],
+        (out_features, 1),
+        &Device::Cpu,
+    )
+    .unwrap()
+    .to_dtype(DType::F16)
+    .unwrap();
+
+    let lw = LinearWeight::quantized(packed, scales, biases, 8);
+    let x = Tensor::from_vec(vec![1.0f32; in_features], (1, in_features), &Device::Cpu)
+        .unwrap()
+        .to_dtype(DType::F16)
+        .unwrap();
+    let out = lw.forward(&x, None, &*backend).unwrap();
+    let vals: Vec<f32> = out
+        .to_dtype(DType::F32)
+        .unwrap()
+        .flatten_all()
+        .unwrap()
+        .to_vec1()
+        .unwrap();
+    for (i, &v) in vals.iter().enumerate() {
+        assert!(
+            v.abs() < 0.5,
+            "output[{i}] = {v}, expected ~0.0 (all weights should be 0)"
+        );
+    }
+}
+
+#[test]
+fn test_linear_weight_quantized_with_bias() {
+    // Verify bias is correctly added after the quantized matmul
+    use cake_core::utils::quantized_linear::LinearWeight;
+
+    let backend = cake_core::backends::create_backend(&Device::Cpu);
+    // All nibbles=0, scale=0, bias=0 → zero weights → matmul output = 0.
+    // Then add a linear bias of [10.0, 20.0] → final output should be [10, 20].
+    let packed = Tensor::from_vec(vec![0u32; 2], (2, 1), &Device::Cpu).unwrap();
+    let scales = Tensor::from_vec(vec![0f32; 2], (2, 1), &Device::Cpu)
+        .unwrap()
+        .to_dtype(DType::F16)
+        .unwrap();
+    let biases = Tensor::from_vec(vec![0f32; 2], (2, 1), &Device::Cpu)
+        .unwrap()
+        .to_dtype(DType::F16)
+        .unwrap();
+    let lw = LinearWeight::quantized(packed, scales, biases, 8);
+
+    let x = Tensor::from_vec(vec![1.0f32; 8], (1, 8), &Device::Cpu)
+        .unwrap()
+        .to_dtype(DType::F16)
+        .unwrap();
+    let linear_bias = Tensor::from_vec(vec![10.0f32, 20.0], 2, &Device::Cpu)
+        .unwrap()
+        .to_dtype(DType::F16)
+        .unwrap();
+    let out = lw.forward(&x, Some(&linear_bias), &*backend).unwrap();
+    let vals: Vec<f32> = out
+        .to_dtype(DType::F32)
+        .unwrap()
+        .flatten_all()
+        .unwrap()
+        .to_vec1()
+        .unwrap();
+    assert!(
+        (vals[0] - 10.0).abs() < 0.5,
+        "expected ~10.0, got {}",
+        vals[0]
+    );
+    assert!(
+        (vals[1] - 20.0).abs() < 0.5,
+        "expected ~20.0, got {}",
+        vals[1]
+    );
+}
+
+#[cfg(feature = "metal")]
+#[test]
+fn test_linear_weight_quantized_metal() {
+    // Test the full fused q4 path through LinearWeight on Metal
+    use candle_core::utils::metal_is_available;
+    use cake_core::utils::quantized_linear::LinearWeight;
+    use half::f16;
+
+    if !metal_is_available() {
+        return;
+    }
+    let metal_device = match Device::new_metal(0) {
+        Ok(d) => d,
+        Err(_) => return,
+    };
+    let backend = cake_core::backends::create_backend(&metal_device);
+
+    let m = 2usize;
+    let in_features = 32usize;
+    let out_features = 4usize;
+    let group_size = 32usize;
+    let num_groups = 1usize;
+    let packed_cols = in_features / 8; // 4
+
+    // All nibbles=1, scale=1.0, bias=0.0 → all weights = 1.0
+    // x = [1.0; 32] → dot = 32.0
+    let packed_data = vec![0x11111111u32; out_features * packed_cols];
+    let scales_f16 = vec![f16::from_f32(1.0); out_features * num_groups];
+    let biases_f16 = vec![f16::from_f32(0.0); out_features * num_groups];
+
+    let packed_tensor =
+        Tensor::from_vec(packed_data, (out_features, packed_cols), &metal_device).unwrap();
+    let scales_tensor =
+        Tensor::from_vec(scales_f16, (out_features, num_groups), &metal_device).unwrap();
+    let biases_tensor =
+        Tensor::from_vec(biases_f16, (out_features, num_groups), &metal_device).unwrap();
+
+    let lw = LinearWeight::quantized(packed_tensor, scales_tensor, biases_tensor, group_size);
+
+    let x = Tensor::from_vec(
+        vec![f16::from_f32(1.0); m * in_features],
+        (m, in_features),
+        &metal_device,
+    )
+    .unwrap();
+    let result = lw.forward(&x, None, &*backend).unwrap();
+
+    assert_eq!(result.dims(), &[m, out_features]);
+    let result_f32: Vec<f32> = result
+        .to_device(&Device::Cpu)
+        .unwrap()
+        .to_dtype(DType::F32)
+        .unwrap()
+        .flatten_all()
+        .unwrap()
+        .to_vec1()
+        .unwrap();
+
+    for (i, &v) in result_f32.iter().enumerate() {
+        assert!(
+            (v - 32.0).abs() < 1.0,
+            "output[{i}] = {v}, expected ~32.0"
+        );
+    }
+}
+
+#[cfg(feature = "metal")]
+#[test]
+fn test_linear_weight_quantized_metal_3d_input() {
+    // Test the Metal q4 path with 3D input (batch, seq, features) to verify
+    // the batched reshape in q4_linear_forward works.
+    use candle_core::utils::metal_is_available;
+    use cake_core::utils::quantized_linear::LinearWeight;
+    use half::f16;
+
+    if !metal_is_available() {
+        return;
+    }
+    let metal_device = match Device::new_metal(0) {
+        Ok(d) => d,
+        Err(_) => return,
+    };
+    let backend = cake_core::backends::create_backend(&metal_device);
+
+    let batch = 1usize;
+    let seq_len = 4usize;
+    let in_features = 16usize;
+    let out_features = 2usize;
+    let group_size = 8usize;
+    let num_groups = 2usize;
+    let packed_cols = in_features / 8; // 2
+
+    // All nibbles=1, scale=1.0, bias=0.0 → weights = 1.0
+    // x = [1.0; 16] → each row dot = 16.0
+    let packed_data = vec![0x11111111u32; out_features * packed_cols];
+    let scales_f16 = vec![f16::from_f32(1.0); out_features * num_groups];
+    let biases_f16 = vec![f16::from_f32(0.0); out_features * num_groups];
+
+    let packed_tensor =
+        Tensor::from_vec(packed_data, (out_features, packed_cols), &metal_device).unwrap();
+    let scales_tensor =
+        Tensor::from_vec(scales_f16, (out_features, num_groups), &metal_device).unwrap();
+    let biases_tensor =
+        Tensor::from_vec(biases_f16, (out_features, num_groups), &metal_device).unwrap();
+
+    let lw = LinearWeight::quantized(packed_tensor, scales_tensor, biases_tensor, group_size);
+
+    // 3D input: (1, 4, 16)
+    let x = Tensor::from_vec(
+        vec![f16::from_f32(1.0); batch * seq_len * in_features],
+        (batch, seq_len, in_features),
+        &metal_device,
+    )
+    .unwrap();
+    let result = lw.forward(&x, None, &*backend).unwrap();
+
+    assert_eq!(result.dims(), &[batch, seq_len, out_features]);
+    let result_f32: Vec<f32> = result
+        .to_device(&Device::Cpu)
+        .unwrap()
+        .to_dtype(DType::F32)
+        .unwrap()
+        .flatten_all()
+        .unwrap()
+        .to_vec1()
+        .unwrap();
+
+    for (i, &v) in result_f32.iter().enumerate() {
+        assert!(
+            (v - 16.0).abs() < 1.0,
+            "output[{i}] = {v}, expected ~16.0"
+        );
+    }
+}

From 8f45e75ea2e40dcb703a0bf82507149f0c6267ae Mon Sep 17 00:00:00 2001
From: cjchanh
Date: Thu, 30 Apr 2026 16:16:01 -0600
Subject: [PATCH 2/3] chore(cake): apply spec 151 portfolio gitignore sweep
 (q4-metal-patchset branch)

Adds 4 portfolio-wide patterns (release/, release/evidence/autopilot_*/,
__pycache__/, CRAFT_GATE_RESULT.json) to .gitignore. Mirrors the same sweep
applied on cake/main this session (commit 786006d). Branch-scoped to avoid
cross-branch divergence.
---
 .gitignore | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/.gitignore b/.gitignore
index b3c46bf..85b7858 100644
--- a/.gitignore
+++ b/.gitignore
@@ -98,3 +98,9 @@ topology-*.yml
 # Autoresearch generated files (machine-specific)
 autoresearch/**/baseline.txt
 autoresearch/**/experiments.tsv
+
+# auto: spec 151 — portfolio trivial gitignore sweep
+release/
+release/evidence/autopilot_*/
+__pycache__/
+CRAFT_GATE_RESULT.json

From 22de82b588beaedf4f36e90dcdab1d3ac0715bfa Mon Sep 17 00:00:00 2001
From: cjchanh
Date: Thu, 30 Apr 2026 16:16:01 -0600
Subject: [PATCH 3/3] fix(cake-q4): filter single-file safetensors push by
 layer (spec 199 row resolution)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Mobile workers receiving a single-file `.safetensors` model previously
got the FULL file regardless of layer assignment. On 4 GiB single-file
models (Qwen2.5-7B-Instruct-4bit) this exceeded iPad jetsam budgets and
crashed with `early eof`.
Same root cause as PR #83 against cake/main, but applied here on
q4-metal-patchset (PR #82's source branch) since PR #83's branch
(`fix/single-file-layer-filter` at ee01115) has API drift against
current upstream and isn't cleanly rebasable.

Changes:

* cake-core/src/utils/split.rs:
  - extract `reduce_for_layers(&Index, &[String])` from the
    worker-specific `reduce_for_worker` (more general, layer-list-driven)
  - introduce `ReducedModelBundle { index_json, safetensors }` for the
    reduced-bundle return type
  - add `build_reduced_single_file_bundle(model_path, layers)` that
    reads the safetensors header, filters tensors by layer prefixes,
    and emits a minimal safetensors blob + matching index.json

* cake-core/src/cake/sharding/mod.rs:
  - replace the single-file fallback (which pushed the full model
    regardless of layer) with the reduced-bundle path
  - generalize `inline_files: HashMap<String, Vec<u8>>` so both the
    indexed and single-file paths can stream multiple inline blobs
    (index + reduced safetensors)
  - import `HashMap` (already had `HashSet`)

Test coverage and benchmark updates pair with this in the existing
q4-metal-patchset commits.

Closes spec 199's cake-q4-branch NEEDS_OPERATOR_DECISION row with
disposition: COMMIT (intentional q4 follow-up; preserves PR #82
contribution path; commit stays local until operator authorizes fork
push).

Spec: 199-triage-dirty-trees-across-active-portfolio (cake-q4-branch row)
SOP: ~/Documents/Centennial/SOPs/CDS_Stuck_Spec_Triage_SOP_v1.md §3.A v1.1
Triage report: ~/ai/evidence/spec-096-triage-20260430/TRIAGE_REPORT.md
---
 cake-core/src/cake/sharding/mod.rs | 59 +++++++++++++------
 cake-core/src/utils/split.rs       | 94 +++++++++++++++++++++++++++---
 2 files changed, 126 insertions(+), 27 deletions(-)

diff --git a/cake-core/src/cake/sharding/mod.rs b/cake-core/src/cake/sharding/mod.rs
index f761e30..ac438cf 100644
--- a/cake-core/src/cake/sharding/mod.rs
+++ b/cake-core/src/cake/sharding/mod.rs
@@ -24,7 +24,7 @@ pub use client::*;
 pub use proto::*;
 pub use worker::*;
 
-use std::collections::HashSet;
+use std::collections::{HashMap, HashSet};
 use std::io::Write;
 use std::path::{Path, PathBuf};
 use std::time::{Duration, Instant};
@@ -498,7 +498,7 @@ async fn push_model_data(
 
     // Determine which safetensors shard files contain the assigned layers
     let index_path = model_path.join("model.safetensors.index.json");
-    let mut filtered_index: Option<Vec<u8>> = None;
+    let mut inline_files: HashMap<String, Vec<u8>> = HashMap::new();
     if index_path.exists() {
         files_to_send.push(index_path.clone());
         let index_data = std::fs::read(&index_path)?;
@@ -531,7 +531,10 @@
                 serde_json::Value::Object(needed_weights),
             );
         }
-        filtered_index = Some(serde_json::to_vec_pretty(&index_json)?);
+        inline_files.insert(
+            "model.safetensors.index.json".to_string(),
+            serde_json::to_vec_pretty(&index_json)?,
+        );
 
         log::info!(
             "[{}] pushing {} shard file(s) + config + tokenizer + index",
@@ -543,10 +546,19 @@
             files_to_send.push(model_path.join(shard));
         }
     } else {
-        // Single safetensors file
+        // Single safetensors file: generate a reduced bundle for just the
+        // assigned layers so mobile workers do not receive the full model.
         let single = model_path.join("model.safetensors");
         if single.exists() {
-            files_to_send.push(single);
+            let bundle =
+                crate::utils::split::build_reduced_single_file_bundle(model_path, layers)?;
+            files_to_send.push(model_path.join("model.safetensors.index.json"));
+            files_to_send.push(model_path.join("reduced.safetensors"));
+            inline_files.insert("model.safetensors.index.json".to_string(), bundle.index_json);
+            inline_files.insert("reduced.safetensors".to_string(), bundle.safetensors);
+            log::info!(
+                "[{}] pushing reduced single-file bundle + config + tokenizer + index",
+                worker_name
+            );
         }
     }
 
@@ -561,20 +573,9 @@ async fn push_model_data(
             .to_string_lossy()
             .to_string();
 
-        // Use filtered index if this is the index file (small, keep in-memory)
-        let is_index = filename == "model.safetensors.index.json";
-        let small_data = if is_index {
-            if let Some(ref data) = filtered_index {
-                Some(data.clone())
-            } else {
-                Some(
-                    std::fs::read(file_path)
-                        .map_err(|e| anyhow!("failed to read {}: {}", file_path.display(), e))?,
-                )
-            }
-        } else {
-            None
-        };
+        // Small generated files (filtered index / reduced bundle) are sent
+        // directly from memory rather than read back from disk.
+        let small_data = inline_files.get(&filename).cloned();
 
         let total_size = if let Some(ref data) = small_data {
             data.len() as u64
@@ -1135,6 +1136,26 @@ mod tests {
         assert!(has_valid_model_cache(tmp.path(), &layers));
     }
 
+    #[test]
+    fn has_valid_model_cache_reduced_single_file_bundle() {
+        let tmp = tempfile::tempdir().unwrap();
+        fs::write(tmp.path().join("config.json"), "{}").unwrap();
+        let index = serde_json::json!({
+            "weight_map": {
+                "model.layers.0.attn.weight": "reduced.safetensors"
+            }
+        });
+        fs::write(
+            tmp.path().join("model.safetensors.index.json"),
+            serde_json::to_string(&index).unwrap(),
+        )
+        .unwrap();
+        fs::write(tmp.path().join("reduced.safetensors"), "data").unwrap();
+
+        let layers = vec!["model.layers.0".to_string()];
+        assert!(has_valid_model_cache(tmp.path(), &layers));
+    }
+
     #[test]
     fn has_valid_model_cache_sharded_missing_layer() {
         let tmp = tempfile::tempdir().unwrap();
diff --git a/cake-core/src/utils/split.rs b/cake-core/src/utils/split.rs
index b74d47d..169c416 100644
--- a/cake-core/src/utils/split.rs
+++ b/cake-core/src/utils/split.rs
@@ -34,6 +34,11 @@ struct TensorStore {
     data: Vec<u8>,
 }
 
+pub(crate) struct ReducedModelBundle {
+    pub index_json: Vec<u8>,
+    pub safetensors: Vec<u8>,
+}
+
 impl View for TensorStore {
     fn dtype(&self) -> Dtype {
         self.dtype
@@ -85,17 +90,15 @@ fn load_index(data_path: &Path) -> Result<Index> {
     }
 }
 
-fn reduce_for_worker(
-    index: &Index,
-    worker: &Node,
-) -> Result<(Index, HashMap<String, Vec<String>>)> {
-    log::info!("worker: {}", &worker.host);
-
+fn reduce_for_layers(index: &Index, layers: &[String]) -> (Index, HashMap<String, Vec<String>>) {
     let mut reduced: HashMap<String, Vec<String>> = HashMap::new();
     let mut new_index = Index::new();
 
     for (layer_full_name, filename) in &index.weight_map {
-        if worker.is_text_model_layer_owner(layer_full_name) {
+        let is_owned = layers
+            .iter()
+            .any(|layer| layer_full_name.starts_with(&format!("{}.", layer)));
+        if is_owned {
             if let Some(layers) = reduced.get_mut(filename) {
                 layers.push(layer_full_name.to_string());
             } else {
@@ -109,7 +112,12 @@
         }
     }
 
-    Ok((new_index, reduced))
+    (new_index, reduced)
+}
+
+fn reduce_for_worker(index: &Index, worker: &Node) -> Result<(Index, HashMap<String, Vec<String>>)> {
+    log::info!("worker: {}", &worker.host);
+    Ok(reduce_for_layers(index, &worker.layers))
 }
 
 fn create_new_metadata(
@@ -148,6 +156,23 @@ fn create_new_metadata(
     Ok(metadata)
 }
 
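+// Flow sketch for the reduced bundle (names match the items below): the
+// sharding code builds one bundle per worker from that worker's layer list
+// and streams both generated blobs from memory instead of re-reading disk.
+//
+//     let bundle = build_reduced_single_file_bundle(model_path, &layers)?;
+//     // bundle.index_json  -> pushed as "model.safetensors.index.json"
+//     // bundle.safetensors -> pushed as "reduced.safetensors"
+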
+pub(crate) fn build_reduced_single_file_bundle(
+    data_path: &Path,
+    layers: &[String],
+) -> Result<ReducedModelBundle> {
+    let index = load_index(data_path)?;
+    let (new_index, reduced) = reduce_for_layers(&index, layers);
+    let metadata = create_new_metadata(data_path, &reduced)?;
+
+    let index_json = serde_json::to_vec_pretty(&new_index)?;
+    let safetensors = safetensors::serialize(metadata, None)?;
+
+    Ok(ReducedModelBundle {
+        index_json,
+        safetensors,
+    })
+}
+
 /// Split a model into per-worker bundles.
 ///
 /// Each bundle contains a reduced safetensors file with only the worker's assigned tensors,
@@ -367,4 +392,57 @@ mod tests {
         assert_eq!(deserialized.weight_map.len(), 1);
         assert_eq!(deserialized.weight_map["tensor.weight"], "file.safetensors");
     }
+
+    #[test]
+    fn build_reduced_single_file_bundle_extracts_only_selected_layers() {
+        let tmp = tempfile::tempdir().unwrap();
+
+        let mut metadata = HashMap::new();
+        metadata.insert(
+            "model.layers.0.attn.weight".to_string(),
+            TensorStore {
+                dtype: Dtype::F32,
+                shape: vec![1],
+                data: 1.0f32.to_le_bytes().to_vec(),
+            },
+        );
+        metadata.insert(
+            "model.layers.1.attn.weight".to_string(),
+            TensorStore {
+                dtype: Dtype::F32,
+                shape: vec![1],
+                data: 2.0f32.to_le_bytes().to_vec(),
+            },
+        );
+        metadata.insert(
+            "model.embed_tokens.weight".to_string(),
+            TensorStore {
+                dtype: Dtype::F32,
+                shape: vec![1],
+                data: 3.0f32.to_le_bytes().to_vec(),
+            },
+        );
+
+        let single_path = tmp.path().join("model.safetensors");
+        safetensors::serialize_to_file(metadata, None, &single_path).unwrap();
+
+        let bundle =
+            build_reduced_single_file_bundle(tmp.path(), &["model.layers.0".to_string()]).unwrap();
+
+        let index_json: serde_json::Value = serde_json::from_slice(&bundle.index_json).unwrap();
+        let weight_map = index_json["weight_map"].as_object().unwrap();
+        assert_eq!(weight_map.len(), 1);
+        assert_eq!(
+            weight_map["model.layers.0.attn.weight"].as_str().unwrap(),
+            "reduced.safetensors"
+        );
+
+        let tensors = SafeTensors::deserialize(&bundle.safetensors).unwrap();
+        let tensor_names: Vec<_> = tensors
+            .tensors()
+            .into_iter()
+            .map(|(name, _)| name)
+            .collect();
+        assert_eq!(tensor_names, vec!["model.layers.0.attn.weight"]);
+    }
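+
+    // Added sketch: pin the dot-suffixed prefix match in reduce_for_layers,
+    // which keeps "model.layers.1" from also claiming "model.layers.10.*".
+    #[test]
+    fn reduce_for_layers_requires_dot_boundary() {
+        let mut index = Index::new();
+        index.weight_map.insert(
+            "model.layers.1.attn.weight".to_string(),
+            "f.safetensors".to_string(),
+        );
+        index.weight_map.insert(
+            "model.layers.10.attn.weight".to_string(),
+            "f.safetensors".to_string(),
+        );
+
+        let (new_index, _) = reduce_for_layers(&index, &["model.layers.1".to_string()]);
+        assert_eq!(new_index.weight_map.len(), 1);
+        assert!(new_index.weight_map.contains_key("model.layers.1.attn.weight"));
+    }
+}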