Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 47 additions & 3 deletions cake-core/src/cake/sharding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,53 @@ pub async fn master_setup(
&worker.host
);

let mut stream = TcpStream::connect(&worker.host)
.await
.map_err(|e| anyhow!("can't connect to {}: {}", &worker.host, e))?;
// Make up to 3 connection attempts, backing off exponentially between them
// (1 s before the second attempt, 2 s before the third).
// iOS workers may need a brief moment for the TCP listener to become
// fully reachable after the UDP discovery advertisement is sent.
// 10 s per attempt is enough for a LAN connection while still failing
// fast enough to give a useful error within ~33 s overall
// (3 x 10 s timeouts + 3 s of backoff).
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
const MAX_ATTEMPTS: u32 = 3;
let mut last_err: Option<anyhow::Error> = None;
let mut stream = None;
for attempt in 0..MAX_ATTEMPTS {
if attempt > 0 {
let delay = Duration::from_secs(1u64 << (attempt - 1));
log::warn!(
"retrying connection to '{}' at {} in {:.1}s (attempt {}/{}) ...",
&worker.name,
&worker.host,
delay.as_secs_f32(),
attempt + 1,
MAX_ATTEMPTS,
);
tokio::time::sleep(delay).await;
}
match tokio::time::timeout(CONNECT_TIMEOUT, TcpStream::connect(&worker.host)).await {
Ok(Ok(s)) => {
stream = Some(s);
break;
}
Ok(Err(e)) => {
log::warn!(
"connect attempt {}/{} to '{}' at {} failed: {}",
attempt + 1, MAX_ATTEMPTS, &worker.name, &worker.host, e
);
last_err = Some(anyhow!("can't connect to {}: {}", &worker.host, e));
}
Err(_) => {
log::warn!(
"connect attempt {}/{} to '{}' at {} timed out after {:.0}s",
attempt + 1, MAX_ATTEMPTS, &worker.name, &worker.host,
CONNECT_TIMEOUT.as_secs_f32(),
);
last_err = Some(anyhow!("can't connect to {}: connection timed out", &worker.host));
}
}
}
let mut stream = stream.ok_or_else(|| {
last_err.unwrap_or_else(|| anyhow!("can't connect to {}: all attempts failed", &worker.host))
})?;
let _ = stream.set_nodelay(true);

// Mutual authentication
Expand Down
6 changes: 3 additions & 3 deletions cake-core/src/models/common/disk_expert_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -511,10 +511,10 @@ impl ExpertProvider for DiskExpertProvider {
for name in [&names.gate_proj, &names.up_proj, &names.down_proj] {
if let Some((bytes, _, _)) = self.storage.tensor_bytes(name) {
unsafe {
libc::posix_madvise(
bytes.as_ptr() as *mut _,
libc::madvise(
bytes.as_ptr() as *mut libc::c_void,
bytes.len(),
libc::POSIX_MADV_WILLNEED,
libc::MADV_WILLNEED,
);
}
}
Expand Down
9 changes: 4 additions & 5 deletions cake-core/src/utils/tensor_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,10 @@ impl MappedShard {
if ptr == libc::MAP_FAILED {
anyhow::bail!("mmap failed: {}", std::io::Error::last_os_error());
}
// Hint: sequential access pattern — triggers aggressive readahead on NVMe
unsafe { libc::posix_madvise(ptr, len, libc::POSIX_MADV_SEQUENTIAL); }
// Hint: sequential access pattern — triggers aggressive readahead on NVMe.
// Use madvise, which is not part of POSIX but is provided by Linux, Android,
// the BSDs and macOS. posix_madvise is not defined in Android's libc.
unsafe { libc::madvise(ptr, len, libc::MADV_SEQUENTIAL); }
Ok(Self { mmap_ptr: ptr as *const u8, mmap_len: len })
}

Expand Down Expand Up @@ -238,9 +240,6 @@ pub struct SafetensorsStorage {
index: HashMap<String, TensorMeta>,
/// Memory-mapped shard files (indexed by shard_idx in TensorMeta).
shards: Vec<MappedShard>,
/// File handles for non-mmap fallback.
#[cfg(not(unix))]
files: Vec<File>,
}

impl SafetensorsStorage {
Expand Down
5 changes: 5 additions & 0 deletions cake-mobile-app/iosApp/iosApp/Info.plist
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
<key>NSBonjourServices</key>
<array>
<string>_cake._udp</string>
<string>_cake._tcp</string>
</array>
<key>UIBackgroundModes</key>
<array>
<string>voip</string>
</array>
<key>CADisableMinimumFrameDurationOnPhone</key>
<true/>
Expand Down
37 changes: 31 additions & 6 deletions cake-mobile/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,15 +391,32 @@ async fn run_zero_config_worker(
update_status("loading", "Loading model weights...", 0.0);
log_mobile(&format!("[cake-mobile] creating Context::from_args (cpu={})...", force_cpu));

// Install a panic hook so that any panic during model loading is logged to
// the mobile log (visible in the app's diagnostic output).
let prev_hook = std::panic::take_hook();
std::panic::set_hook(Box::new(move |info| {
let msg = format!("[cake-mobile] PANIC: {}", info);
log_mobile(&msg);
}));

let ctx_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
Context::from_args(args)
}));
// Context::from_args is CPU- and I/O-intensive (loads model weight files).
// Run it on a dedicated blocking thread to avoid starving the Tokio async
// runtime while the listener is still open and waiting for the master's
// inference reconnect.
let ctx_result = match tokio::task::spawn_blocking(move || {
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| Context::from_args(args)))
})
.await
{
Ok(r) => r,
Err(join_err) => {
std::panic::set_hook(prev_hook);
let msg = format!("context loading task failed: {}", join_err);
log_mobile(&format!("[cake-mobile] ERROR: {}", msg));
update_status("error", &msg, 0.0);
return msg;
}
};

std::panic::set_hook(prev_hook);

Expand Down Expand Up @@ -456,15 +473,23 @@ async fn run_direct_worker(name: &str, model: &str, address: &str) -> String {
update_status("loading", "Downloading model...", 0.0);
log_mobile("[cake-mobile] creating context...");

let mut ctx = match Context::from_args(args) {
Ok(ctx) => {
// Context::from_args downloads / loads large model weight files. Run it
// on a dedicated blocking thread so the Tokio async runtime stays live.
let ctx_result = tokio::task::spawn_blocking(move || Context::from_args(args)).await;
let mut ctx = match ctx_result {
Ok(Ok(ctx)) => {
log_mobile(&format!("[cake-mobile] context created, device={:?}", ctx.device));
ctx
}
Err(e) => {
Ok(Err(e)) => {
update_status("error", &format!("Failed: {}", e), 0.0);
return format!("context creation failed: {}", e);
}
Err(join_err) => {
let msg = format!("context loading task failed: {}", join_err);
update_status("error", &msg, 0.0);
return msg;
}
};

update_status("serving", "Ready — serving inference", 1.0);
Expand Down
Loading