-
Notifications
You must be signed in to change notification settings - Fork 37
Hc/broadcast #190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Hc/broadcast #190
Changes from 4 commits
8da60af
307634d
c356ea6
1f643ca
93b9c3f
2e07c35
fa0438f
51d3dca
1efff25
71d85d7
06f8620
4bb0e0e
168dcff
70acfca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,3 +1,5 @@ | ||||||
| use core::num; | ||||||
|
|
||||||
| use arith::SimdField; | ||||||
| use rayon::iter::{IntoParallelIterator, ParallelIterator}; | ||||||
| use serdes::ExpSerde; | ||||||
|
|
@@ -44,7 +46,7 @@ pub struct KernelCall { | |||||
| num_parallel: usize, | ||||||
| input_handles: Vec<DeviceMemoryHandle>, | ||||||
| output_handles: Vec<DeviceMemoryHandle>, | ||||||
| is_broadcast: Vec<bool>, | ||||||
| is_broadcast: Vec<usize>, | ||||||
| } | ||||||
|
|
||||||
| #[derive(PartialEq, Eq, Clone, Debug, ExpSerde)] | ||||||
|
|
@@ -53,7 +55,7 @@ pub struct ProofTemplate { | |||||
| pub commitment_indices: Vec<usize>, | ||||||
| pub commitment_bit_orders: Vec<BitOrder>, | ||||||
| pub parallel_count: usize, | ||||||
| pub is_broadcast: Vec<bool>, | ||||||
| pub is_broadcast: Vec<usize>, | ||||||
| } | ||||||
|
|
||||||
| impl ProofTemplate { | ||||||
|
|
@@ -69,7 +71,7 @@ impl ProofTemplate { | |||||
| pub fn parallel_count(&self) -> usize { | ||||||
| self.parallel_count | ||||||
| } | ||||||
| pub fn is_broadcast(&self) -> &[bool] { | ||||||
| pub fn is_broadcast(&self) -> &[usize] { | ||||||
| &self.is_broadcast | ||||||
| } | ||||||
| } | ||||||
|
|
@@ -156,17 +158,19 @@ fn check_shape_compat( | |||||
| kernel_shape: &Shape, | ||||||
| io_shape: &Shape, | ||||||
| parallel_count: usize, | ||||||
| ) -> Option<bool> { | ||||||
| ) -> Option<usize> { | ||||||
| if kernel_shape.len() == io_shape.len() { | ||||||
| if *kernel_shape == *io_shape { | ||||||
| Some(true) | ||||||
| Some(parallel_count) | ||||||
| } else { | ||||||
| None | ||||||
| } | ||||||
| } else if kernel_shape.len() + 1 == io_shape.len() { | ||||||
| if io_shape.iter().skip(1).eq(kernel_shape.iter()) { | ||||||
| if io_shape[0] == parallel_count { | ||||||
| Some(false) | ||||||
| Some(1) | ||||||
| } else if parallel_count % io_shape[0] == 0 { | ||||||
| Some(parallel_count / io_shape[0]) | ||||||
| } else { | ||||||
| None | ||||||
| } | ||||||
|
|
@@ -299,18 +303,15 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||
| &self, | ||||||
| values: &[SIMDField<C>], | ||||||
| s: &mut [SIMDField<C>], | ||||||
| is_broadcast: bool, | ||||||
| is_broadcast: usize, | ||||||
| parallel_index: usize, | ||||||
| chunk_size: Option<usize>, | ||||||
| ) { | ||||||
| if is_broadcast { | ||||||
| s.copy_from_slice(values); | ||||||
| } else { | ||||||
| let chunk_size = chunk_size.unwrap(); | ||||||
| s.copy_from_slice( | ||||||
| &values[chunk_size * parallel_index..chunk_size * (parallel_index + 1)], | ||||||
| ); | ||||||
| } | ||||||
| let chunk_size = chunk_size.unwrap(); | ||||||
| let start_index = chunk_size * parallel_index % values.len(); | ||||||
| s.copy_from_slice( | ||||||
| &values[start_index..(start_index + chunk_size)], | ||||||
| ); | ||||||
| } | ||||||
|
|
||||||
| pub fn call_kernel( | ||||||
|
|
@@ -332,7 +333,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||
| .enumerate() | ||||||
| { | ||||||
| if !spec.is_input { | ||||||
| is_broadcast.push(false); | ||||||
| is_broadcast.push(1); | ||||||
| continue; | ||||||
| } | ||||||
| /*println!( | ||||||
|
|
@@ -350,7 +351,8 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||
| .as_ref() | ||||||
| .unwrap() | ||||||
| .shape_history | ||||||
| .get_initial_split_list(!ib); | ||||||
| .get_initial_split_list(ib/num_parallel+1); | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The expression
Suggested change
|
||||||
| // let isl = vec![1,64,4096]; | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||||||
| let t = io.as_ref().unwrap().id; | ||||||
| self.device_memories[t].required_shape_products = merge_shape_products( | ||||||
| &isl, | ||||||
|
|
@@ -371,7 +373,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||
| } | ||||||
| } | ||||||
| for (io_spec, ib) in kernel.io_specs().iter().zip(is_broadcast.iter()) { | ||||||
| if io_spec.is_output && *ib { | ||||||
| if io_spec.is_output && *ib!=1 { | ||||||
| panic!("Output is broadcasted, but it shouldn't be"); | ||||||
| } | ||||||
| } | ||||||
|
|
@@ -381,11 +383,12 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||
| let mut outputs_tmp = vec![Vec::new(); kernel.io_specs().len()]; | ||||||
| let mut ir_inputs_all = vec![Vec::new(); kernel.io_specs().len()]; | ||||||
| let mut chunk_sizes: Vec<Option<usize>> = vec![None; kernel.io_specs().len()]; | ||||||
| for (((input, &ib), ir_inputs), chunk_size) in ios | ||||||
| for ((((input, &ib), ir_inputs), chunk_size), kernel_shape) in ios | ||||||
| .iter() | ||||||
| .zip(is_broadcast.iter()) | ||||||
| .zip(ir_inputs_all.iter_mut()) | ||||||
| .zip(chunk_sizes.iter_mut()) | ||||||
| .zip(kernel.io_shapes().iter()) | ||||||
| { | ||||||
| if input.is_none() { | ||||||
| continue; | ||||||
|
|
@@ -394,9 +397,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||
| let values = handle | ||||||
| .shape_history | ||||||
| .permute_vec(&self.device_memories[handle.id].values); | ||||||
| if !ib { | ||||||
| *chunk_size = Some(values.len() / num_parallel); | ||||||
| } | ||||||
| *chunk_size = Some(kernel_shape.iter().product()); | ||||||
| *ir_inputs = values; | ||||||
| } | ||||||
| let mut ir_inputs_per_parallel = Vec::new(); | ||||||
|
|
@@ -469,7 +470,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||
| .as_ref() | ||||||
| .unwrap() | ||||||
| .shape_history | ||||||
| .get_initial_split_list(true), | ||||||
| .get_initial_split_list(1), | ||||||
| &self.device_memories[id].required_shape_products, | ||||||
| ); | ||||||
| *output = handle.clone(); | ||||||
|
|
@@ -513,7 +514,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||
| .zip(kernel_call.input_handles.iter()) | ||||||
| .zip(kernel_call.is_broadcast.iter()) | ||||||
| { | ||||||
| if !spec.is_input || ib { | ||||||
| if !spec.is_input || ib > 1 { | ||||||
| continue; | ||||||
| } | ||||||
| let pad_shape = get_pad_shape(input_handle).unwrap(); | ||||||
|
|
@@ -526,7 +527,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||
| .zip(kernel_call.output_handles.iter()) | ||||||
| .zip(kernel_call.is_broadcast.iter()) | ||||||
| { | ||||||
| if !spec.is_output || ib { | ||||||
| if !spec.is_output || ib > 1 { | ||||||
| continue; | ||||||
| } | ||||||
| let pad_shape = get_pad_shape(output_handle).unwrap(); | ||||||
|
|
@@ -549,7 +550,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||
| if x != 1 && x != kernel_call.num_parallel { | ||||||
| let sh_tmp = handle.shape_history.reshape(&[x, total / x]); | ||||||
| dm.required_shape_products = merge_shape_products( | ||||||
| &sh_tmp.get_initial_split_list(true), | ||||||
| &sh_tmp.get_initial_split_list(1), | ||||||
| &dm.required_shape_products, | ||||||
| ); | ||||||
| } | ||||||
|
|
@@ -576,7 +577,6 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||
| self.state = ContextState::ComputationGraphDone; | ||||||
|
|
||||||
| let dm_shapes = self.propagate_and_get_shapes(); | ||||||
|
|
||||||
| let (mut cg_kernels, cg_proof_templates, cg_commitments_lens) = if let Some(cg) = cg { | ||||||
| for (i, kernel) in cg.kernels.iter().enumerate() { | ||||||
| assert_eq!(self.kernels.add(kernel), i); | ||||||
|
|
@@ -622,11 +622,11 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||
| let mut psi = Vec::new(); | ||||||
| for (s, &ib) in pad_shapes_input.iter().zip(kernel_call.is_broadcast.iter()) { | ||||||
| psi.push(s.as_ref().map(|t| { | ||||||
| if ib { | ||||||
| if ib == kernel_call.num_parallel { | ||||||
| t.0.clone() | ||||||
| } else { | ||||||
| keep_shape_since(&t.0, kernel_call.num_parallel) | ||||||
| } | ||||||
| } else{ | ||||||
| keep_shape_since(&t.0, kernel_call.num_parallel/ib) | ||||||
| } | ||||||
| })); | ||||||
| } | ||||||
| let mut pso = Vec::new(); | ||||||
|
|
@@ -635,7 +635,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||
| .zip(kernel_call.is_broadcast.iter()) | ||||||
| { | ||||||
| pso.push(s.as_ref().map(|t| { | ||||||
| if ib { | ||||||
| if ib == kernel_call.num_parallel { | ||||||
| t.0.clone() | ||||||
| } else { | ||||||
| keep_shape_since(&t.0, kernel_call.num_parallel) | ||||||
|
|
@@ -661,7 +661,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||
| commitment_indices.push(handle.as_ref().unwrap().id); | ||||||
| commitment_bit_orders.push(shape.1.clone()); | ||||||
| is_broadcast.push(ib); | ||||||
| if !ib { | ||||||
| if ib == 1 { | ||||||
| any_shape = Some(shape.0.clone()); | ||||||
| } | ||||||
| } | ||||||
|
|
@@ -678,7 +678,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||
| commitment_indices.push(handle.as_ref().unwrap().id); | ||||||
| commitment_bit_orders.push(shape.1.clone()); | ||||||
| is_broadcast.push(ib); | ||||||
| if !ib { | ||||||
| if ib == 1 { | ||||||
| any_shape = Some(shape.0.clone()); | ||||||
| } | ||||||
| } | ||||||
|
|
@@ -695,7 +695,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||
| dm_max += 1; | ||||||
| commitment_bit_orders.push((0..n.trailing_zeros() as usize).collect()); | ||||||
| commitments_lens.push(n); | ||||||
| is_broadcast.push(false); | ||||||
| is_broadcast.push(1); | ||||||
| } | ||||||
|
|
||||||
| let kernel_id = self.kernels.add(&kernel); | ||||||
|
|
@@ -778,9 +778,8 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||
| let values = handle | ||||||
| .shape_history | ||||||
| .permute_vec(&self.device_memories[handle.id].values); | ||||||
| if !ib { | ||||||
| *chunk_size = Some(values.len() / kernel_call.num_parallel); | ||||||
| } | ||||||
| let kernel_shape = handle.shape_history.shape(); | ||||||
| *chunk_size = Some(kernel_shape.iter().product()); | ||||||
| *ir_inputs = values; | ||||||
| } | ||||||
| for (((output, &ib), ir_inputs), chunk_size) in kernel_call | ||||||
|
|
@@ -800,7 +799,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||
| let values = handle | ||||||
| .shape_history | ||||||
| .permute_vec(&self.device_memories[handle.id].values); | ||||||
| assert!(!ib); | ||||||
| assert!(ib == 1); | ||||||
| *chunk_size = Some(values.len() / kernel_call.num_parallel); | ||||||
| *ir_inputs = values; | ||||||
| } | ||||||
|
|
@@ -823,7 +822,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||
| self.ir_copy_from_device_memory( | ||||||
| ir_inputs, | ||||||
| &mut inputs[*input_start..*input_end], | ||||||
| chunk_size.is_none(), | ||||||
| chunk_size.unwrap_or(2), | ||||||
| parallel_i, | ||||||
| *chunk_size, | ||||||
| ); | ||||||
|
|
@@ -843,7 +842,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||
| self.ir_copy_from_device_memory( | ||||||
| ir_outputs, | ||||||
| &mut inputs[*output_start..*output_end], | ||||||
| chunk_size.is_none(), | ||||||
| chunk_size.unwrap_or(1), | ||||||
| parallel_i, | ||||||
| *chunk_size, | ||||||
| ); | ||||||
|
|
@@ -891,6 +890,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||
| .map(|dm| { | ||||||
| let shape = prefix_products_to_shape(&dm.required_shape_products); | ||||||
| let im = shape_padded_mapping(&shape); | ||||||
| let tmp = im.map_inputs(&dm.values); | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||||||
| im.map_inputs(&dm.values) | ||||||
| }) | ||||||
| .collect() | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -301,9 +301,8 @@ fn reorder_ir_inputs<C: Config>( | |||||
| lc_in[i].len = n; | ||||||
| assert!(var_max % n == 0); | ||||||
| let im = shape_padded_mapping(&pad_shapes[i]); | ||||||
| // println!("{:?}", im.mapping()); | ||||||
| for (j, &k) in im.mapping().iter().enumerate() { | ||||||
| var_new_id[prev + k + 1] = var_max + j + 1; | ||||||
| var_new_id[prev + j + 1] = var_max + k + 1; | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This line appears to have incorrect logic for remapping variable IDs. The current implementation is `var_new_id[prev + j + 1] = var_max + k + 1;`. The logic should be `var_new_id[prev + k + 1] = var_max + j + 1;`.
Suggested change
|
||||||
| } | ||||||
| var_max += n; | ||||||
| } | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `is_broadcast` parameter is unused within `ir_copy_from_device_memory` and should be removed to improve code clarity. The function signature and all call sites (e.g., lines 418, 825, and 845) should be updated accordingly.