Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cake-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,6 @@ fn build_worker_topology(
/// Return the base cache directory for zero-config model data.
fn cache_base_dir() -> PathBuf {
dirs::cache_dir()
.unwrap_or_else(|| PathBuf::from("/tmp"))
.unwrap_or_else(std::env::temp_dir)
.join("cake")
}
84 changes: 82 additions & 2 deletions cake-core/src/cake/discovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,26 @@ pub fn detect_backend() -> String {

/// Detect the CUDA toolkit version via nvcc.
pub fn detect_cuda_version() -> Option<String> {
// Try nvcc from PATH first.
let output = std::process::Command::new("nvcc")
.arg("--version")
.output()
.ok()?;
.output();

// On Windows, nvcc may not be in PATH but CUDA_PATH is always set by the installer.
// Only fall back to CUDA_PATH if the PATH lookup failed.
#[cfg(target_os = "windows")]
let output = output.or_else(|_| {
let cuda_path = std::env::var("CUDA_PATH")
.map_err(|_| std::io::Error::new(std::io::ErrorKind::NotFound, "CUDA_PATH not set"))?;
let nvcc_path = std::path::PathBuf::from(&cuda_path)
.join("bin")
.join("nvcc.exe");
std::process::Command::new(nvcc_path)
.arg("--version")
.output()
});

let output = output.ok()?;
if !output.status.success() {
return None;
}
Expand Down Expand Up @@ -330,6 +346,24 @@ fn detect_system_memory() -> u64 {
}
}

// On Windows, get total physical memory.
// Try PowerShell/CIM first (wmic is deprecated and removed in some Win11 builds).
#[cfg(target_os = "windows")]
{
if let Ok(output) = std::process::Command::new("powershell")
.args(["-NoProfile", "-Command",
"(Get-CimInstance Win32_ComputerSystem).TotalPhysicalMemory"])
.output()
{
if output.status.success() {
let stdout = String::from_utf8_lossy(&output.stdout);
if let Ok(bytes) = stdout.trim().parse::<u64>() {
return bytes;
}
}
}
}

// On Linux and Android, read /proc/meminfo
// Note: Android uses target_os = "android", not "linux", so both are listed.
#[cfg(any(target_os = "linux", target_os = "android"))]
Expand Down Expand Up @@ -530,6 +564,52 @@ fn get_broadcast_addresses() -> Vec<Ipv4Addr> {
}
}

// On Windows, parse `ipconfig` to compute broadcast addresses from IP + subnet mask.
// ipconfig groups lines under adapter headers (non-indented). Indented lines contain
// the IPv4 address and subnet mask. We reset last_ip on adapter boundaries so a
// stale IP from a previous adapter can't pair with the wrong subnet mask.
#[cfg(target_os = "windows")]
{
if let Ok(output) = std::process::Command::new("ipconfig").output() {
let stdout = String::from_utf8_lossy(&output.stdout);
let mut last_ip: Option<Ipv4Addr> = None;
for line in stdout.lines() {
// Adapter headers have no leading whitespace; detail lines are indented.
// Reset state on adapter boundaries to prevent cross-adapter mispairing.
if !line.starts_with(' ') && !line.starts_with('\t') {
last_ip = None;
}

let trimmed = line.trim();
// Match lines containing an IPv4 address value (x.x.x.x after a colon).
// This is locale-independent: we look for ": <ipv4>" on any line and
// distinguish address vs mask by checking if it looks like a mask
// (starts with 255.).
if let Some(colon_idx) = trimmed.rfind(':') {
let value = trimmed[colon_idx + 1..].trim();
if let Ok(ip) = value.parse::<Ipv4Addr>() {
let octets = ip.octets();
if octets[0] == 255 {
// This is a subnet mask — pair with last_ip
if let Some(addr) = last_ip.take() {
let ip_bits = u32::from(addr);
let mask_bits = u32::from(ip);
let brd = (ip_bits & mask_bits) | (!mask_bits);
let brd_addr = Ipv4Addr::from(brd);
if !brd_addr.is_loopback() {
addrs.push(brd_addr);
}
}
} else if !ip.is_loopback() {
// This is an IPv4 address
last_ip = Some(ip);
}
}
}
}
}
}

// Always include the limited broadcast as a fallback
addrs.push(Ipv4Addr::BROADCAST);
addrs.dedup();
Expand Down
2 changes: 1 addition & 1 deletion cake-core/src/utils/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ fn hf_cache_dir() -> Option<PathBuf> {
/// Return the Cake cluster cache directory if it exists.
fn cake_cache_dir() -> Option<PathBuf> {
let cache = dirs::cache_dir()
.unwrap_or_else(|| PathBuf::from("/tmp"))
.unwrap_or_else(std::env::temp_dir)
.join("cake");
if cache.exists() {
Some(cache)
Expand Down
8 changes: 8 additions & 0 deletions docs/install.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ version supported by your driver to avoid library version mismatches:
CUDA_HOME=/usr/local/cuda-12.4 cargo build --release --features cuda
```

**CUDA on Windows:**

Windows workers require an NVIDIA GPU driver and the [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads) >= 12.2 (the installer sets `CUDA_PATH` automatically).

```powershell
cargo build --release --features cuda
```

### Pre-Volta NVIDIA GPUs

For Pascal, Maxwell, or other GPUs with compute capability < 7.0, the upstream `candle-kernels` crate requires patches. See [`cuda-compat/`](../cuda-compat/) for a one-command fix:
Expand Down