diff --git a/cake-cli/src/main.rs b/cake-cli/src/main.rs index 7bc79a3..3c99c74 100644 --- a/cake-cli/src/main.rs +++ b/cake-cli/src/main.rs @@ -308,6 +308,6 @@ fn build_worker_topology( /// Return the base cache directory for zero-config model data. fn cache_base_dir() -> PathBuf { dirs::cache_dir() - .unwrap_or_else(|| PathBuf::from("/tmp")) + .unwrap_or_else(std::env::temp_dir) .join("cake") } diff --git a/cake-core/src/cake/discovery.rs b/cake-core/src/cake/discovery.rs index 19c3440..b5c3ac5 100644 --- a/cake-core/src/cake/discovery.rs +++ b/cake-core/src/cake/discovery.rs @@ -243,10 +243,26 @@ pub fn detect_backend() -> String { /// Detect the CUDA toolkit version via nvcc. pub fn detect_cuda_version() -> Option { + // Try nvcc from PATH first. let output = std::process::Command::new("nvcc") .arg("--version") - .output() - .ok()?; + .output(); + + // On Windows, nvcc may not be in PATH but CUDA_PATH is always set by the installer. + // Only fall back to CUDA_PATH if the PATH lookup failed. + #[cfg(target_os = "windows")] + let output = output.or_else(|_| { + let cuda_path = std::env::var("CUDA_PATH") + .map_err(|_| std::io::Error::new(std::io::ErrorKind::NotFound, "CUDA_PATH not set"))?; + let nvcc_path = std::path::PathBuf::from(&cuda_path) + .join("bin") + .join("nvcc.exe"); + std::process::Command::new(nvcc_path) + .arg("--version") + .output() + }); + + let output = output.ok()?; if !output.status.success() { return None; } @@ -330,6 +346,24 @@ fn detect_system_memory() -> u64 { } } + // On Windows, get total physical memory. + // Try PowerShell/CIM first (wmic is deprecated and removed in some Win11 builds). + #[cfg(target_os = "windows")] + { + if let Ok(output) = std::process::Command::new("powershell") + .args(["-NoProfile", "-Command", + "(Get-CimInstance Win32_ComputerSystem).TotalPhysicalMemory"]) + .output() + { + if output.status.success() { + let stdout = String::from_utf8_lossy(&output.stdout); + if let Ok(bytes) = stdout.trim().parse::() { + return bytes; + } + } + } + } + // On Linux and Android, read /proc/meminfo // Note: Android uses target_os = "android", not "linux", so both are listed. #[cfg(any(target_os = "linux", target_os = "android"))] @@ -530,6 +564,52 @@ fn get_broadcast_addresses() -> Vec { } } + // On Windows, parse `ipconfig` to compute broadcast addresses from IP + subnet mask. + // ipconfig groups lines under adapter headers (non-indented). Indented lines contain + // the IPv4 address and subnet mask. We reset last_ip on adapter boundaries so a + // stale IP from a previous adapter can't pair with the wrong subnet mask. + #[cfg(target_os = "windows")] + { + if let Ok(output) = std::process::Command::new("ipconfig").output() { + let stdout = String::from_utf8_lossy(&output.stdout); + let mut last_ip: Option = None; + for line in stdout.lines() { + // Adapter headers have no leading whitespace; detail lines are indented. + // Reset state on adapter boundaries to prevent cross-adapter mispairing. + if !line.starts_with(' ') && !line.starts_with('\t') { + last_ip = None; + } + + let trimmed = line.trim(); + // Match lines containing an IPv4 address value (x.x.x.x after a colon). + // This is locale-independent: we look for ": " on any line and + // distinguish address vs mask by checking if it looks like a mask + // (starts with 255.). + if let Some(colon_idx) = trimmed.rfind(':') { + let value = trimmed[colon_idx + 1..].trim(); + if let Ok(ip) = value.parse::() { + let octets = ip.octets(); + if octets[0] == 255 { + // This is a subnet mask — pair with last_ip + if let Some(addr) = last_ip.take() { + let ip_bits = u32::from(addr); + let mask_bits = u32::from(ip); + let brd = (ip_bits & mask_bits) | (!mask_bits); + let brd_addr = Ipv4Addr::from(brd); + if !brd_addr.is_loopback() { + addrs.push(brd_addr); + } + } + } else if !ip.is_loopback() { + // This is an IPv4 address + last_ip = Some(ip); + } + } + } + } + } + } + // Always include the limited broadcast as a fallback addrs.push(Ipv4Addr::BROADCAST); addrs.dedup(); diff --git a/cake-core/src/utils/models.rs b/cake-core/src/utils/models.rs index 4d4245d..d84b22f 100644 --- a/cake-core/src/utils/models.rs +++ b/cake-core/src/utils/models.rs @@ -171,7 +171,7 @@ fn hf_cache_dir() -> Option { /// Return the Cake cluster cache directory if it exists. fn cake_cache_dir() -> Option { let cache = dirs::cache_dir() - .unwrap_or_else(|| PathBuf::from("/tmp")) + .unwrap_or_else(std::env::temp_dir) .join("cake"); if cache.exists() { Some(cache) diff --git a/docs/install.md b/docs/install.md index cdb12a4..9af472a 100644 --- a/docs/install.md +++ b/docs/install.md @@ -48,6 +48,14 @@ version supported by your driver to avoid library version mismatches: CUDA_HOME=/usr/local/cuda-12.4 cargo build --release --features cuda ``` +**CUDA on Windows:** + +Windows workers require an NVIDIA GPU driver and the [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads) >= 12.2 (the installer sets `CUDA_PATH` automatically). + +```powershell +cargo build --release --features cuda +``` + ### Pre-Volta NVIDIA GPUs For Pascal, Maxwell, or other GPUs with compute capability < 7.0, the upstream `candle-kernels` crate requires patches. See [`cuda-compat/`](../cuda-compat/) for a one-command fix: