diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs index bf4a07c4..64b39c03 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs @@ -140,7 +140,11 @@ where setup_timer.stop(); - SharedMemoryEngine::read_pcs_setup_from_shared_memory() + // Prover setup not needed on client side (server does the proving). + // Verifier setup is required for verification, so read it from shared memory. + let (_prover_setup, verifier_setup) = + SharedMemoryEngine::read_pcs_setup_from_shared_memory::(); + (ExpanderProverSetup::default(), verifier_setup) } pub fn client_send_witness_and_prove( @@ -152,8 +156,39 @@ where { let timer = Timer::new("prove", true); + // Reset ack signal, then write witness + SharedMemoryEngine::reset_witness_ack(); SharedMemoryEngine::write_witness_to_shared_memory::(device_memories); - wait_async(ClientHttpHelper::request_prove()); + + #[cfg(all(target_os = "linux", target_env = "gnu"))] + { + extern "C" { + fn malloc_trim(pad: usize) -> i32; + } + unsafe { + malloc_trim(0); + } + } + + // Async: send prove request + poll for witness ack to release shared memory early + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let prove_handle = tokio::spawn(async { + ClientHttpHelper::request_prove().await; + }); + + // Poll witness_ack; once server confirms read, release witness shared memory + tokio::task::spawn_blocking(|| { + SharedMemoryEngine::wait_for_witness_read_complete(); + unsafe { + super::shared_memory_utils::SHARED_MEMORY.witness = None; + } + }) + .await + .expect("Witness cleanup task failed"); + + prove_handle.await.expect("Prove task failed"); + }); let proof = SharedMemoryEngine::read_proof_from_shared_memory(); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs index 27919a50..f51dd509 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs @@ -149,6 +149,9 @@ where let mut witness_win = state.wt_shared_memory_win.lock().await; S::setup_shared_witness(&state.global_mpi_config, &mut witness, &mut witness_win); + // Signal client: witness has been read, shared memory can be released + SharedMemoryEngine::signal_witness_read_complete(); + let prover_setup_guard = state.prover_setup.lock().await; let computation_graph = state.computation_graph.lock().await; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs index 648f33a8..b03aa639 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs @@ -18,12 +18,15 @@ pub struct SharedMemory { pub pcs_setup: Option, pub witness: Option, pub proof: Option, + /// 1-byte signal: 0 = witness not read, 1 = server finished reading witness + pub witness_ack: Option, } pub static mut SHARED_MEMORY: SharedMemory = SharedMemory { pcs_setup: None, witness: None, proof: None, + witness_ack: None, }; pub struct SharedMemoryEngine {} @@ -106,6 +109,56 @@ impl SharedMemoryEngine { Self::read_object_from_shared_memory("pcs_setup", 0) } + /// Client: reset witness_ack to 0 (call before writing witness) + pub fn reset_witness_ack() { + unsafe { + Self::allocate_shared_memory_if_necessary( + &mut SHARED_MEMORY.witness_ack, + "witness_ack", + 1, + ); + let ptr = SHARED_MEMORY.witness_ack.as_mut().unwrap().as_ptr(); + std::ptr::write_volatile(ptr, 0u8); + } + } + + /// Server: set witness_ack to 1 (call after reading witness) + pub fn signal_witness_read_complete() { + let shmem = ShmemConf::new() + .flink("witness_ack") + .open() + .expect("Failed to open witness_ack shared memory"); + unsafe { + std::ptr::write_volatile(shmem.as_ptr(), 1u8); + } + } + + /// Client: poll until witness_ack becomes 1, with a timeout to avoid hanging + /// if the server crashes. + pub fn wait_for_witness_read_complete() { + const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300); + let start = std::time::Instant::now(); + unsafe { + let ptr = SHARED_MEMORY + .witness_ack + .as_ref() + .expect("witness_ack not initialized, call reset_witness_ack first") + .as_ptr() as *const u8; + loop { + if std::ptr::read_volatile(ptr) != 0 { + break; + } + if start.elapsed() > TIMEOUT { + panic!( + "Timed out waiting for server to read witness ({}s)", + TIMEOUT.as_secs() + ); + } + std::thread::sleep(std::time::Duration::from_millis(10)); + } + } + } + pub fn write_witness_to_shared_memory(values: Vec>) { let total_size = std::mem::size_of::() + values