-
Notifications
You must be signed in to change notification settings - Fork 37
Zq/detach compile #179
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Zq/detach compile #179
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -26,13 +26,13 @@ use super::{ | |||||||||||||||||
| pub use macros::call_kernel; | ||||||||||||||||||
|
|
||||||||||||||||||
| struct DeviceMemory<C: Config> { | ||||||||||||||||||
| values: Vec<SIMDField<C>>, | ||||||||||||||||||
| pub values: Vec<SIMDField<C>>, | ||||||||||||||||||
| required_shape_products: Vec<usize>, | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| #[derive(Clone, Debug, ExpSerde)] | ||||||||||||||||||
| pub struct DeviceMemoryHandleRaw { | ||||||||||||||||||
| id: usize, | ||||||||||||||||||
| pub id: usize, | ||||||||||||||||||
| shape_history: ShapeHistory, | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -217,13 +217,29 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| pub fn new_device_memory(&mut self, shape: Shape) -> (DeviceMemoryHandle, usize) { | ||||||||||||||||||
| let t = shape_vec_len(&shape); | ||||||||||||||||||
| let required_shape_products = if t == 1 { vec![1] } else { vec![1, t] }; | ||||||||||||||||||
| self.device_memories.push(DeviceMemory { | ||||||||||||||||||
| values: vec![], | ||||||||||||||||||
| required_shape_products, | ||||||||||||||||||
| }); | ||||||||||||||||||
| (Some(DeviceMemoryHandleRaw { | ||||||||||||||||||
| id: self.device_memories.len() - 1, | ||||||||||||||||||
| shape_history: ShapeHistory::new(shape), | ||||||||||||||||||
| }), self.device_memories.len() - 1) | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| pub fn copy_to_device<T: VecShaped<CircuitField<C>>>( | ||||||||||||||||||
| &mut self, | ||||||||||||||||||
| host_memory: &T, | ||||||||||||||||||
| ) -> DeviceMemoryHandle { | ||||||||||||||||||
| device_memory_id: usize, | ||||||||||||||||||
| ) { | ||||||||||||||||||
| assert!(device_memory_id < self.device_memories.len(), "The device memory dosen't exist."); | ||||||||||||||||||
| let (flat, shape) = flatten_shaped(host_memory); | ||||||||||||||||||
| assert_eq!(shape_vec_len(&shape), shape_vec_len(&self.device_memories[device_memory_id].required_shape_products), "The len of values doesn't match."); | ||||||||||||||||||
| let simd_flat = pack_vec::<C>(&flat); | ||||||||||||||||||
| make_device_mem(&mut self.device_memories, simd_flat, shape) | ||||||||||||||||||
| self.device_memories[device_memory_id].values = simd_flat; | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| pub fn copy_to_device_and_pack_simd<T: VecShaped<CircuitField<C>>>( | ||||||||||||||||||
|
|
@@ -367,72 +383,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||||||||||||||
|
|
||||||||||||||||||
| let kernel_id = self.kernel_primitives.add(kernel); | ||||||||||||||||||
|
|
||||||||||||||||||
| let mut outputs_tmp = vec![Vec::new(); kernel.io_specs().len()]; | ||||||||||||||||||
| let mut ir_inputs_all = vec![Vec::new(); kernel.io_specs().len()]; | ||||||||||||||||||
| let mut chunk_sizes: Vec<Option<usize>> = vec![None; kernel.io_specs().len()]; | ||||||||||||||||||
| for (((input, &ib), ir_inputs), chunk_size) in ios | ||||||||||||||||||
| .iter() | ||||||||||||||||||
| .zip(is_broadcast.iter()) | ||||||||||||||||||
| .zip(ir_inputs_all.iter_mut()) | ||||||||||||||||||
| .zip(chunk_sizes.iter_mut()) | ||||||||||||||||||
| { | ||||||||||||||||||
| if input.is_none() { | ||||||||||||||||||
| continue; | ||||||||||||||||||
| } | ||||||||||||||||||
| let handle = ensure_handle(input.clone()); | ||||||||||||||||||
| let values = handle | ||||||||||||||||||
| .shape_history | ||||||||||||||||||
| .permute_vec(&self.device_memories[handle.id].values); | ||||||||||||||||||
| if !ib { | ||||||||||||||||||
| *chunk_size = Some(values.len() / num_parallel); | ||||||||||||||||||
| } | ||||||||||||||||||
| *ir_inputs = values; | ||||||||||||||||||
| } | ||||||||||||||||||
| let mut ir_inputs_per_parallel = Vec::new(); | ||||||||||||||||||
| for parallel_i in 0..num_parallel { | ||||||||||||||||||
| let mut ir_inputs = vec![SIMDField::<C>::zero(); kernel.ir_for_calling().input_size()]; | ||||||||||||||||||
| for (i, ((input, input_start), input_end)) in ios | ||||||||||||||||||
| .iter() | ||||||||||||||||||
| .zip(kernel.ir_input_offsets().iter()) | ||||||||||||||||||
| .zip(kernel.ir_input_offsets().iter().skip(1)) | ||||||||||||||||||
| .enumerate() | ||||||||||||||||||
| { | ||||||||||||||||||
| if input.is_none() { | ||||||||||||||||||
| continue; | ||||||||||||||||||
| } | ||||||||||||||||||
| self.ir_copy_from_device_memory( | ||||||||||||||||||
| &ir_inputs_all[i], | ||||||||||||||||||
| &mut ir_inputs[*input_start..*input_end], | ||||||||||||||||||
| is_broadcast[i], | ||||||||||||||||||
| parallel_i, | ||||||||||||||||||
| chunk_sizes[i], | ||||||||||||||||||
| ); | ||||||||||||||||||
| } | ||||||||||||||||||
| ir_inputs_per_parallel.push(ir_inputs); | ||||||||||||||||||
| } | ||||||||||||||||||
| let ir_outputs_per_parallel: Vec<Result<Vec<SIMDField<C>>, Error>> = ir_inputs_per_parallel | ||||||||||||||||||
| .into_par_iter() | ||||||||||||||||||
| .map(|ir_inputs| { | ||||||||||||||||||
| kernel | ||||||||||||||||||
| .ir_for_calling() | ||||||||||||||||||
| .eval_safe_simd(ir_inputs, &[], &self.hint_caller) | ||||||||||||||||||
| }) | ||||||||||||||||||
| .collect(); | ||||||||||||||||||
| for ir_outputs in ir_outputs_per_parallel { | ||||||||||||||||||
| let ir_outputs = ir_outputs?; | ||||||||||||||||||
| for (((spec, output_start), output_end), out) in kernel | ||||||||||||||||||
| .io_specs() | ||||||||||||||||||
| .iter() | ||||||||||||||||||
| .zip(kernel.ir_output_offsets().iter()) | ||||||||||||||||||
| .zip(kernel.ir_output_offsets().iter().skip(1)) | ||||||||||||||||||
| .zip(outputs_tmp.iter_mut()) | ||||||||||||||||||
| { | ||||||||||||||||||
| if !spec.is_output { | ||||||||||||||||||
| continue; | ||||||||||||||||||
| } | ||||||||||||||||||
| out.extend_from_slice(&ir_outputs[*output_start..*output_end]); | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
| let mut outputs_tmp: Vec<Vec<SIMDField::<C>>> = vec![Vec::new(); kernel.io_specs().len()]; | ||||||||||||||||||
| let input_handles = ios.to_vec(); | ||||||||||||||||||
| let mut output_handles = vec![None; kernel.io_specs().len()]; | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -447,12 +398,13 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||||||||||||||
| *output = None; | ||||||||||||||||||
| continue; | ||||||||||||||||||
| } | ||||||||||||||||||
| let handle = make_device_mem( | ||||||||||||||||||
| &mut self.device_memories, | ||||||||||||||||||
| ov, | ||||||||||||||||||
| shape_prepend(shape, num_parallel), | ||||||||||||||||||
| ); | ||||||||||||||||||
| let id = handle.as_ref().unwrap().id; | ||||||||||||||||||
| // let handle = make_device_mem( | ||||||||||||||||||
| // &mut self.device_memories, | ||||||||||||||||||
| // ov, | ||||||||||||||||||
| // shape_prepend(shape, num_parallel), | ||||||||||||||||||
| // ); | ||||||||||||||||||
| // let id = handle.as_ref().unwrap().id; | ||||||||||||||||||
| let (handle, id) = self.new_device_memory(shape_prepend(shape, num_parallel)); | ||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This block of commented-out code should be removed to improve code clarity and maintainability.
Suggested change
|
||||||||||||||||||
| self.device_memories[id].required_shape_products = merge_shape_products( | ||||||||||||||||||
| &handle | ||||||||||||||||||
| .as_ref() | ||||||||||||||||||
|
|
@@ -732,6 +684,92 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
| self.state = ContextState::WitnessDone; | ||||||||||||||||||
|
|
||||||||||||||||||
| for kernel_call in self.kernel_calls.iter() { | ||||||||||||||||||
| let kernel = self.kernel_primitives.get(kernel_call.kernel_id); | ||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||
| let num_parallel = kernel_call.num_parallel; | ||||||||||||||||||
| let is_broadcast = &kernel_call.is_broadcast; | ||||||||||||||||||
|
|
||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||
| let mut ir_inputs_all = vec![Vec::new(); kernel.io_specs().len()]; | ||||||||||||||||||
| let mut chunk_sizes: Vec<Option<usize>> = vec![None; kernel.io_specs().len()]; | ||||||||||||||||||
| for (((input, &ib), ir_inputs), chunk_size) in kernel_call.input_handles | ||||||||||||||||||
| .iter() | ||||||||||||||||||
| .zip(is_broadcast.iter()) | ||||||||||||||||||
| .zip(ir_inputs_all.iter_mut()) | ||||||||||||||||||
| .zip(chunk_sizes.iter_mut()) | ||||||||||||||||||
| { | ||||||||||||||||||
| if input.is_none() { | ||||||||||||||||||
| continue; | ||||||||||||||||||
| } | ||||||||||||||||||
| let handle = ensure_handle(input.clone()); | ||||||||||||||||||
| let values = handle | ||||||||||||||||||
| .shape_history | ||||||||||||||||||
| .permute_vec(&self.device_memories[handle.id].values); | ||||||||||||||||||
| if !ib { | ||||||||||||||||||
| *chunk_size = Some(values.len() / num_parallel); | ||||||||||||||||||
| } | ||||||||||||||||||
| *ir_inputs = values; | ||||||||||||||||||
| } | ||||||||||||||||||
| let mut ir_inputs_per_parallel = Vec::new(); | ||||||||||||||||||
| for parallel_i in 0..num_parallel { | ||||||||||||||||||
| let mut ir_inputs = vec![SIMDField::<C>::zero(); kernel.ir_for_calling().input_size()]; | ||||||||||||||||||
| for (i, ((input, input_start), input_end)) in kernel_call.input_handles | ||||||||||||||||||
| .iter() | ||||||||||||||||||
| .zip(kernel.ir_input_offsets().iter()) | ||||||||||||||||||
| .zip(kernel.ir_input_offsets().iter().skip(1)) | ||||||||||||||||||
| .enumerate() | ||||||||||||||||||
| { | ||||||||||||||||||
| if input.is_none() { | ||||||||||||||||||
| continue; | ||||||||||||||||||
| } | ||||||||||||||||||
| self.ir_copy_from_device_memory( | ||||||||||||||||||
| &ir_inputs_all[i], | ||||||||||||||||||
| &mut ir_inputs[*input_start..*input_end], | ||||||||||||||||||
| is_broadcast[i], | ||||||||||||||||||
| parallel_i, | ||||||||||||||||||
| chunk_sizes[i], | ||||||||||||||||||
| ); | ||||||||||||||||||
| } | ||||||||||||||||||
| ir_inputs_per_parallel.push(ir_inputs); | ||||||||||||||||||
| } | ||||||||||||||||||
| let ir_outputs_per_parallel: Vec<Result<Vec<SIMDField<C>>, Error>> = ir_inputs_per_parallel | ||||||||||||||||||
| .into_par_iter() | ||||||||||||||||||
| .map(|ir_inputs| { | ||||||||||||||||||
| kernel | ||||||||||||||||||
| .ir_for_calling() | ||||||||||||||||||
| .eval_safe_simd(ir_inputs, &[], &self.hint_caller) | ||||||||||||||||||
| }) | ||||||||||||||||||
| .collect(); | ||||||||||||||||||
|
|
||||||||||||||||||
| let mut outputs_tmp: Vec<Vec<SIMDField::<C>>> = vec![Vec::new(); kernel.io_specs().len()]; | ||||||||||||||||||
| for ir_outputs in ir_outputs_per_parallel { | ||||||||||||||||||
| let ir_outputs = ir_outputs?; | ||||||||||||||||||
| for (((spec, output_start), output_end), out) in kernel | ||||||||||||||||||
| .io_specs() | ||||||||||||||||||
| .iter() | ||||||||||||||||||
| .zip(kernel.ir_output_offsets().iter()) | ||||||||||||||||||
| .zip(kernel.ir_output_offsets().iter().skip(1)) | ||||||||||||||||||
| .zip(outputs_tmp.iter_mut()) | ||||||||||||||||||
| { | ||||||||||||||||||
| if !spec.is_output { | ||||||||||||||||||
| continue; | ||||||||||||||||||
| } | ||||||||||||||||||
| out.extend_from_slice(&ir_outputs[*output_start..*output_end]); | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| for ((output, spec), ov) in kernel_call.output_handles | ||||||||||||||||||
| .iter() | ||||||||||||||||||
| .zip(kernel.io_specs().iter()) | ||||||||||||||||||
| .zip(outputs_tmp.into_iter()) | ||||||||||||||||||
| { | ||||||||||||||||||
| if !spec.is_output { | ||||||||||||||||||
| continue; | ||||||||||||||||||
| } | ||||||||||||||||||
| let output_id = output.as_ref().unwrap().id; | ||||||||||||||||||
| self.device_memories[output_id].values = ov; | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
|
Comment on lines
+688
to
+759
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The code for kernel execution, which was moved into Specifically, the preparation of inputs for parallel execution is duplicated. To improve maintainability and reduce code duplication, consider refactoring the common logic for preparing inputs into a separate helper function. This would make the |
||||||||||||||||||
|
|
||||||||||||||||||
| for (kernel_call, proof_template) in | ||||||||||||||||||
| self.kernel_calls.iter().zip(self.proof_templates.iter()) | ||||||||||||||||||
|
|
||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -327,7 +327,8 @@ fn get_computation_graph() -> ComputationGraph<M31Config> { | |
| } | ||
|
|
||
| println!("prepare data ok"); | ||
| let p = ctx.copy_to_device(&p); | ||
| let p_value = p; | ||
| let (p, p_id) = ctx.new_device_memory(vec![N_PARALLEL, 64 * 8]); | ||
| println!("copy to device ok"); | ||
| let mut out = None; | ||
| call_kernel!(ctx, kernel, N_PARALLEL, p, mut out).unwrap(); | ||
|
|
@@ -338,6 +339,7 @@ fn get_computation_graph() -> ComputationGraph<M31Config> { | |
| assert_eq!(out[0][0], expected_res[0][0]); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| let computation_graph = ctx.compile_computation_graph().unwrap(); | ||
| ctx.copy_to_device(&p_value, p_id); | ||
|
|
||
| computation_graph | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.