diff --git a/circuit-std-rs/src/logup.rs b/circuit-std-rs/src/logup.rs index a88ba102..d99b8b7c 100644 --- a/circuit-std-rs/src/logup.rs +++ b/circuit-std-rs/src/logup.rs @@ -328,6 +328,35 @@ impl LogUpSingleKeyTable { assert_eq_rational(builder, &v_table, &v_query); } + + pub fn final_check_with_query_count>( + &mut self, + builder: &mut B, + query_count: &[Variable], + ) { + if self.table.is_empty() || self.query_keys.is_empty() { + panic!("empty table or empty query"); + } + + let value_len = self.table[0].len(); + + let alpha = builder.get_random_value(); + let randomness = get_column_randomness(builder, value_len); + + let table_combined = combine_columns(builder, &self.table, &randomness); + let v_table = logup_poly_val(builder, &table_combined, query_count, &alpha); + + let query_combined = combine_columns(builder, &self.query_results, &randomness); + let one = builder.constant(1); + let v_query = logup_poly_val( + builder, + &query_combined, + &vec![one; query_combined.len()], + &alpha, + ); + + assert_eq_rational(builder, &v_table, &v_query); + } } pub struct LogUpRangeProofTable { @@ -455,6 +484,25 @@ impl LogUpRangeProofTable { ); assert_eq_rational(builder, &v_table, &v_query); } + + pub fn final_check_with_query_count>( + &mut self, + builder: &mut B, + query_count: &[Variable], + ) { + let alpha = builder.get_random_value(); + + let v_table = logup_poly_val(builder, &self.table_keys, query_count, &alpha); + + let one = builder.constant(1); + let v_query = logup_poly_val( + builder, + &self.query_keys, + &vec![one; self.query_keys.len()], + &alpha, + ); + assert_eq_rational(builder, &v_table, &v_query); + } } pub fn query_count_hint(inputs: &[F], outputs: &mut [F]) -> Result<(), Error> { diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index 40081f72..a929ffac 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -888,6 +888,17 @@ impl>> Context { ContextState::WitnessDone, "Please finish computation graph and witness solving before exporting device memories." ); + self.export_device_memories_impl() + } + + /// Export device memories without checking the context state. + /// Use this when you need to export memories outside the normal workflow, + /// e.g., for memory optimization where you want to export and then drop the context. + pub fn export_device_memories_unchecked(&self) -> Vec>> { + self.export_device_memories_impl() + } + + fn export_device_memories_impl(&self) -> Vec>> { self.device_memories .iter() .map(|dm| { diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs index 7d7fed98..6a559fa1 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs @@ -73,3 +73,15 @@ where wait_async(ClientHttpHelper::request_exit()) } } + +impl ExpanderNoOverSubscribe +where + as ExpanderPCS>>::Commitment: + AsRef< as ExpanderPCS>>::Commitment>, +{ + /// Lightweight prove that doesn't require computation_graph or prover_setup. + /// Use this after setup() to allow releasing those large data structures before proving. + pub fn prove_lightweight(device_memories: Vec>>) { + client_send_witness_and_prove::(device_memories); + } +} diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs index 42315b39..bf4a07c4 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs @@ -112,7 +112,11 @@ where let mpi_size = if allow_oversubscribe { max_parallel_count } else { - let num_cpus = prev_power_of_two(num_cpus::get_physical()); + let num_cpus = std::env::var("ZKML_NUM_CPUS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or_else(num_cpus::get_physical); + let num_cpus = prev_power_of_two(num_cpus); if max_parallel_count > num_cpus { num_cpus } else {