Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,9 @@ topology-*.yml
# Autoresearch generated files (machine-specific)
autoresearch/**/baseline.txt
autoresearch/**/experiments.tsv

# auto: spec 151 — portfolio trivial gitignore sweep
release/
release/evidence/autopilot_*/
__pycache__/
CRAFT_GATE_RESULT.json
1,474 changes: 1,268 additions & 206 deletions cake-core/src/backends/metal/mod.rs

Large diffs are not rendered by default.

362 changes: 362 additions & 0 deletions cake-core/src/backends/metal/ops.msl

Large diffs are not rendered by default.

64 changes: 42 additions & 22 deletions cake-core/src/backends/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ pub use cuda::CudaBackend;
mod metal;
#[cfg(feature = "metal")]
pub use self::metal::MetalBackend;
#[cfg(feature = "metal")]
pub use self::metal::q4_matmul_f16;

#[cfg(feature = "vulkan")]
mod vulkan;
Expand Down Expand Up @@ -87,13 +89,7 @@ pub trait ComputeBackend: Send + Sync + std::fmt::Debug {
// ── Fused normalization ──────────────────────────────────────────

/// `rms_norm(x, weight, eps) * silu(z)` — GDN output gating.
fn rms_norm_gated(
&self,
x: &Tensor,
z: &Tensor,
weight: &Tensor,
eps: f32,
) -> Result<Tensor>;
fn rms_norm_gated(&self, x: &Tensor, z: &Tensor, weight: &Tensor, eps: f32) -> Result<Tensor>;

/// `rms_norm(a + b, weight, eps)` — residual + norm fusion.
/// Returns `(residual, normed)` where `residual = a + b`.
Expand Down Expand Up @@ -194,6 +190,20 @@ pub trait ComputeBackend: Send + Sync + std::fmt::Debug {
Ok(weight.clone())
}

/// Pre-process and fuse multiple linear weights.
///
/// The default implementation keeps the historical semantics: all of the
/// `(out_features, in_features)` weight parts are concatenated along dim 0
/// first, and the backend-specific preprocessing hook is invoked exactly
/// once on the resulting fused tensor.
///
/// Backends whose `preprocess_linear_weight` transposes or otherwise
/// expands weights may override this to preprocess each part before
/// concatenation, lowering peak memory during load.
fn preprocess_linear_weights(&self, weights: &[&Tensor]) -> Result<Tensor> {
    // Concatenate along the out_features axis, then defer to the
    // single-tensor preprocessing hook.
    self.preprocess_linear_weight(&Tensor::cat(weights, 0)?)
}

// ── Inference primitives ──────────────────────────────────────────

/// Linear layer forward: `x @ weight^T + bias`.
Expand All @@ -203,12 +213,7 @@ pub trait ComputeBackend: Send + Sync + std::fmt::Debug {
/// (avoids slow broadcast_matmul on CUDA/CPU)
/// - For non-contiguous 3D+: uses broadcast_left on weight
/// - No dtype conversion (caller is responsible)
fn linear_forward(
&self,
x: &Tensor,
weight: &Tensor,
bias: Option<&Tensor>,
) -> Result<Tensor> {
fn linear_forward(&self, x: &Tensor, weight: &Tensor, bias: Option<&Tensor>) -> Result<Tensor> {
let out = match x.dims() {
[b1, b2, m, k] => {
if x.is_contiguous() {
Expand Down Expand Up @@ -240,6 +245,27 @@ pub trait ComputeBackend: Send + Sync + std::fmt::Debug {
}
}

/// Fused 4-bit dequant + matmul: `output = x @ dequant(packed, scales, biases)^T`.
///
/// On Metal, this dispatches to the q4_matmul_f16 MSL kernel, keeping weights
/// at 0.5 bytes/element (4x memory reduction vs F16).
///
/// Default: dequantizes on CPU via `gptq::dequantize_packed_4bit` and calls
/// `linear_forward`. Suboptimal but correct — only Metal overrides this.
fn q4_linear_forward(
    &self,
    packed: &Tensor,
    scales: &Tensor,
    biases: &Tensor,
    x: &Tensor,
    group_size: usize,
) -> Result<Tensor> {
    // CPU fallback path: materialize the full-precision weight (F32),
    // align it with the input's dtype and device, and reuse the generic
    // linear kernel. No bias is fused here — callers add bias separately.
    let weight = crate::utils::gptq::dequantize_packed_4bit(packed, scales, biases, group_size)?
        .to_dtype(x.dtype())?
        .to_device(x.device())?;
    self.linear_forward(x, &weight, None)
}

/// RMS normalization: `x * weight / sqrt(mean(x^2) + eps)`.
fn rms_norm(&self, x: &Tensor, weight: &Tensor, eps: f32) -> Result<Tensor> {
candle_nn::ops::rms_norm(x, weight, eps)
Expand Down Expand Up @@ -285,9 +311,8 @@ pub trait ComputeBackend: Send + Sync + std::fmt::Debug {
match &b_data {
Some(bd) => {
for i in 0..hidden {
out[off + i] =
(((row[i] as f64 - mean) * rstd) * w_data[i] as f64
+ bd[i] as f64) as f32;
out[off + i] = (((row[i] as f64 - mean) * rstd) * w_data[i] as f64
+ bd[i] as f64) as f32;
}
}
None => {
Expand Down Expand Up @@ -530,12 +555,7 @@ pub trait ComputeBackend: Send + Sync + std::fmt::Debug {
/// Create a causal attention mask. Returns a U8 tensor of shape `(seq_len, kv_len)`
/// where 1 = masked (future position), 0 = attend.
/// Callers use `masked_fill` or `where_cond` to apply the mask.
fn causal_mask(
&self,
seq_len: usize,
kv_len: usize,
device: &Device,
) -> Result<Tensor> {
fn causal_mask(&self, seq_len: usize, kv_len: usize, device: &Device) -> Result<Tensor> {
if seq_len == 1 {
return Tensor::zeros((1, kv_len), DType::U8, device);
}
Expand Down
59 changes: 40 additions & 19 deletions cake-core/src/cake/sharding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub use client::*;
pub use proto::*;
pub use worker::*;

use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
use std::io::Write;
use std::path::{Path, PathBuf};
use std::time::{Duration, Instant};
Expand Down Expand Up @@ -498,7 +498,7 @@ async fn push_model_data(

// Determine which safetensors shard files contain the assigned layers
let index_path = model_path.join("model.safetensors.index.json");
let mut filtered_index: Option<Vec<u8>> = None;
let mut inline_files: HashMap<String, Vec<u8>> = HashMap::new();
if index_path.exists() {
files_to_send.push(index_path.clone());
let index_data = std::fs::read(&index_path)?;
Expand Down Expand Up @@ -531,7 +531,10 @@ async fn push_model_data(
serde_json::Value::Object(needed_weights),
);
}
filtered_index = Some(serde_json::to_vec_pretty(&index_json)?);
inline_files.insert(
"model.safetensors.index.json".to_string(),
serde_json::to_vec_pretty(&index_json)?,
);

log::info!(
"[{}] pushing {} shard file(s) + config + tokenizer + index",
Expand All @@ -543,10 +546,19 @@ async fn push_model_data(
files_to_send.push(model_path.join(shard));
}
} else {
// Single safetensors file
// Single safetensors file: generate a reduced bundle for just the
// assigned layers so mobile workers do not receive the full model.
let single = model_path.join("model.safetensors");
if single.exists() {
files_to_send.push(single);
let bundle = crate::utils::split::build_reduced_single_file_bundle(model_path, layers)?;
files_to_send.push(model_path.join("model.safetensors.index.json"));
files_to_send.push(model_path.join("reduced.safetensors"));
inline_files.insert("model.safetensors.index.json".to_string(), bundle.index_json);
inline_files.insert("reduced.safetensors".to_string(), bundle.safetensors);
log::info!(
"[{}] pushing reduced single-file bundle + config + tokenizer + index",
worker_name
);
}
}

Expand All @@ -561,20 +573,9 @@ async fn push_model_data(
.to_string_lossy()
.to_string();

// Use filtered index if this is the index file (small, keep in-memory)
let is_index = filename == "model.safetensors.index.json";
let small_data = if is_index {
if let Some(ref data) = filtered_index {
Some(data.clone())
} else {
Some(
std::fs::read(file_path)
.map_err(|e| anyhow!("failed to read {}: {}", file_path.display(), e))?,
)
}
} else {
None
};
// Small generated files (filtered index / reduced bundle) are sent
// directly from memory rather than read back from disk.
let small_data = inline_files.get(&filename).cloned();

let total_size = if let Some(ref data) = small_data {
data.len() as u64
Expand Down Expand Up @@ -1135,6 +1136,26 @@ mod tests {
assert!(has_valid_model_cache(tmp.path(), &layers));
}

#[test]
fn has_valid_model_cache_reduced_single_file_bundle() {
    // A cache directory containing config.json plus an index that maps the
    // assigned layer into `reduced.safetensors` (and that file itself)
    // must be recognized as a valid reduced single-file bundle.
    let dir = tempfile::tempdir().unwrap();
    fs::write(dir.path().join("config.json"), "{}").unwrap();

    let index = serde_json::json!({
        "weight_map": {
            "model.layers.0.attn.weight": "reduced.safetensors"
        }
    });
    let index_bytes = serde_json::to_string(&index).unwrap();
    fs::write(dir.path().join("model.safetensors.index.json"), index_bytes).unwrap();
    fs::write(dir.path().join("reduced.safetensors"), "data").unwrap();

    let assigned = vec!["model.layers.0".to_string()];
    assert!(has_valid_model_cache(dir.path(), &assigned));
}

#[test]
fn has_valid_model_cache_sharded_missing_layer() {
let tmp = tempfile::tempdir().unwrap();
Expand Down
Loading