Skip to content

Commit 639a18b

Browse files
hczphn and claude committed
Reduce peak memory during prove by releasing witness shared memory early
Add witness_ack shared memory signaling between client and server:

- Client resets a 1-byte ack signal before writing witness
- Server signals ack after reading witness into MPI shared memory
- Client polls for ack, then immediately releases witness shared memory and calls malloc_trim to return memory to OS
- Prove request runs concurrently via tokio async, so witness memory is freed while proving is in progress
- Skip reading PCS setup from shared memory (return default) since the client does not need it after setup

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 16839e4 commit 639a18b

3 files changed

Lines changed: 83 additions & 3 deletions

File tree

expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs

Lines changed: 36 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -87,7 +87,7 @@ where
8787
C: GKREngine,
8888
ECCConfig: Config<FieldConfig = C::FieldConfig>,
8989
{
90-
let setup_timer = Timer::new("setup", true);
90+
let setup_timer = Timer::new("new setup", true);
9191
println!("Starting server with binary: {server_binary}");
9292

9393
let mut bytes = vec![];
@@ -140,7 +140,11 @@ where
140140

141141
setup_timer.stop();
142142

143-
SharedMemoryEngine::read_pcs_setup_from_shared_memory()
143+
// Skip reading PCS setup from shared memory; return default to reduce memory
144+
(
145+
ExpanderProverSetup::default(),
146+
ExpanderVerifierSetup::default(),
147+
)
144148
}
145149

146150
pub fn client_send_witness_and_prove<C, ECCConfig>(
@@ -152,8 +156,37 @@ where
152156
{
153157
let timer = Timer::new("prove", true);
154158

159+
// Reset ack signal, then write witness
160+
SharedMemoryEngine::reset_witness_ack();
155161
SharedMemoryEngine::write_witness_to_shared_memory::<C::FieldConfig>(device_memories);
156-
wait_async(ClientHttpHelper::request_prove());
162+
163+
extern "C" {
164+
fn malloc_trim(pad: usize) -> i32;
165+
}
166+
unsafe {
167+
malloc_trim(0);
168+
}
169+
170+
// Async: send prove request + poll for witness ack to release shared memory early
171+
let rt = tokio::runtime::Runtime::new().unwrap();
172+
rt.block_on(async {
173+
let prove_handle = tokio::spawn(async {
174+
ClientHttpHelper::request_prove().await;
175+
});
176+
177+
// Poll witness_ack; once server confirms read, release witness shared memory
178+
tokio::task::spawn_blocking(|| {
179+
SharedMemoryEngine::wait_for_witness_read_complete();
180+
unsafe {
181+
super::shared_memory_utils::SHARED_MEMORY.witness = None;
182+
malloc_trim(0);
183+
}
184+
})
185+
.await
186+
.expect("Witness cleanup task failed");
187+
188+
prove_handle.await.expect("Prove task failed");
189+
});
157190

158191
let proof = SharedMemoryEngine::read_proof_from_shared_memory();
159192

expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -149,6 +149,9 @@ where
149149
let mut witness_win = state.wt_shared_memory_win.lock().await;
150150
S::setup_shared_witness(&state.global_mpi_config, &mut witness, &mut witness_win);
151151

152+
// Signal client: witness has been read, shared memory can be released
153+
SharedMemoryEngine::signal_witness_read_complete();
154+
152155
let prover_setup_guard = state.prover_setup.lock().await;
153156
let computation_graph = state.computation_graph.lock().await;
154157

expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs

Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,12 +18,15 @@ pub struct SharedMemory {
1818
pub pcs_setup: Option<Shmem>,
1919
pub witness: Option<Shmem>,
2020
pub proof: Option<Shmem>,
21+
/// 1-byte signal: 0 = witness not read, 1 = server finished reading witness
22+
pub witness_ack: Option<Shmem>,
2123
}
2224

2325
pub static mut SHARED_MEMORY: SharedMemory = SharedMemory {
2426
pcs_setup: None,
2527
witness: None,
2628
proof: None,
29+
witness_ack: None,
2730
};
2831

2932
pub struct SharedMemoryEngine {}
@@ -106,6 +109,47 @@ impl SharedMemoryEngine {
106109
Self::read_object_from_shared_memory("pcs_setup", 0)
107110
}
108111

112+
/// Client: reset witness_ack to 0 (call before writing witness)
113+
pub fn reset_witness_ack() {
114+
unsafe {
115+
Self::allocate_shared_memory_if_necessary(
116+
&mut SHARED_MEMORY.witness_ack,
117+
"witness_ack",
118+
1,
119+
);
120+
let ptr = SHARED_MEMORY.witness_ack.as_mut().unwrap().as_ptr();
121+
std::ptr::write_volatile(ptr, 0u8);
122+
}
123+
}
124+
125+
/// Server: set witness_ack to 1 (call after reading witness)
126+
pub fn signal_witness_read_complete() {
127+
let shmem = ShmemConf::new()
128+
.flink("witness_ack")
129+
.open()
130+
.expect("Failed to open witness_ack shared memory");
131+
unsafe {
132+
std::ptr::write_volatile(shmem.as_ptr(), 1u8);
133+
}
134+
}
135+
136+
/// Client: poll until witness_ack becomes 1
137+
pub fn wait_for_witness_read_complete() {
138+
unsafe {
139+
let ptr = SHARED_MEMORY
140+
.witness_ack
141+
.as_ref()
142+
.expect("witness_ack not initialized, call reset_witness_ack first")
143+
.as_ptr() as *const u8;
144+
loop {
145+
if std::ptr::read_volatile(ptr) != 0 {
146+
break;
147+
}
148+
std::thread::sleep(std::time::Duration::from_millis(500));
149+
}
150+
}
151+
}
152+
109153
pub fn write_witness_to_shared_memory<F: FieldEngine>(values: Vec<Vec<F::SimdCircuitField>>) {
110154
let total_size = std::mem::size_of::<usize>()
111155
+ values

0 commit comments

Comments
 (0)