Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions circuit-std-rs/tests/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,14 @@ fn rangeproof_zkcuda_test() {
let kernel: KernelPrimitive<M31Config> = compile_rangeproof_test_kernel().unwrap();
let mut ctx: Context<M31Config, _> = Context::new(hint_registry);

let a = M31::from(1 << 9);
let a = ctx.copy_to_device(&a);
let a_value = M31::from(1 << 9);
let (a, a_id) = ctx.new_device_memory(vec![]);
let a = a.reshape(&[1]);
call_kernel!(ctx, kernel, 1, a).unwrap();

type P = Expander<M31Config>;
let computation_graph = ctx.compile_computation_graph().unwrap();
ctx.copy_to_device(&a_value, a_id);
ctx.solve_witness().unwrap();
let (prover_setup, verifier_setup) = <P as ProvingSystem<M31Config>>::setup(&computation_graph);
let proof = P::prove(
Expand All @@ -180,13 +181,14 @@ fn rangeproof_zkcuda_test_fail() {
let kernel: KernelPrimitive<M31Config> = compile_rangeproof_test_kernel().unwrap();
let mut ctx: Context<M31Config, _> = Context::new(hint_registry);

let a = M31::from(1 << 11);
let a = ctx.copy_to_device(&a);
let a_value = M31::from(1 << 11);
let (a, a_id) = ctx.new_device_memory(vec![]);
let a = a.reshape(&[1]);
call_kernel!(ctx, kernel, 1, a).unwrap();

type P = Expander<M31Config>;
let computation_graph = ctx.compile_computation_graph().unwrap();
ctx.copy_to_device(&a_value, a_id);
ctx.solve_witness().unwrap();
let (prover_setup, verifier_setup) = <P as ProvingSystem<M31Config>>::setup(&computation_graph);
let proof = P::prove(
Expand Down
6 changes: 4 additions & 2 deletions expander_compiler/bin/zkcuda_matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ pub fn zkcuda_matmul<C: Config, P: ProvingSystem<C>, const N: usize>() {
}
}

let a = ctx.copy_to_device(&mat_a);
let b = ctx.copy_to_device(&mat_b);
let (a, a_id) = ctx.new_device_memory(vec![N, M]);
let (b, b_id) = ctx.new_device_memory(vec![M, K]);
let mut c = None;
call_kernel!(ctx, kernel_mul_line, N, a, b, mut c).unwrap();

Expand All @@ -72,6 +72,8 @@ pub fn zkcuda_matmul<C: Config, P: ProvingSystem<C>, const N: usize>() {
assert_eq!(result, expected_result);

let computation_graph = ctx.compile_computation_graph().unwrap();
ctx.copy_to_device(&mat_a, a_id);
ctx.copy_to_device(&mat_b, b_id);
ctx.solve_witness().unwrap();

let (prover_setup, verifier_setup) = P::setup(&computation_graph);
Expand Down
190 changes: 114 additions & 76 deletions expander_compiler/src/zkcuda/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ use super::{
pub use macros::call_kernel;

struct DeviceMemory<C: Config> {
values: Vec<SIMDField<C>>,
pub values: Vec<SIMDField<C>>,
required_shape_products: Vec<usize>,
}

#[derive(Clone, Debug, ExpSerde)]
pub struct DeviceMemoryHandleRaw {
id: usize,
pub id: usize,
shape_history: ShapeHistory,
}

Expand Down Expand Up @@ -217,13 +217,29 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
}
}

pub fn new_device_memory(&mut self, shape: Shape) -> (DeviceMemoryHandle, usize) {
let t = shape_vec_len(&shape);
let required_shape_products = if t == 1 { vec![1] } else { vec![1, t] };
self.device_memories.push(DeviceMemory {
values: vec![],
required_shape_products,
});
(Some(DeviceMemoryHandleRaw {
id: self.device_memories.len() - 1,
shape_history: ShapeHistory::new(shape),
}), self.device_memories.len() - 1)
}

pub fn copy_to_device<T: VecShaped<CircuitField<C>>>(
&mut self,
host_memory: &T,
) -> DeviceMemoryHandle {
device_memory_id: usize,
) {
assert!(device_memory_id < self.device_memories.len(), "The device memory dosen't exist.");
Comment thread
chonpsk marked this conversation as resolved.
Outdated
let (flat, shape) = flatten_shaped(host_memory);
assert_eq!(shape_vec_len(&shape), shape_vec_len(&self.device_memories[device_memory_id].required_shape_products), "The len of values doesn't match.");
let simd_flat = pack_vec::<C>(&flat);
make_device_mem(&mut self.device_memories, simd_flat, shape)
self.device_memories[device_memory_id].values = simd_flat;
}

pub fn copy_to_device_and_pack_simd<T: VecShaped<CircuitField<C>>>(
Expand Down Expand Up @@ -367,72 +383,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {

let kernel_id = self.kernel_primitives.add(kernel);

let mut outputs_tmp = vec![Vec::new(); kernel.io_specs().len()];
let mut ir_inputs_all = vec![Vec::new(); kernel.io_specs().len()];
let mut chunk_sizes: Vec<Option<usize>> = vec![None; kernel.io_specs().len()];
for (((input, &ib), ir_inputs), chunk_size) in ios
.iter()
.zip(is_broadcast.iter())
.zip(ir_inputs_all.iter_mut())
.zip(chunk_sizes.iter_mut())
{
if input.is_none() {
continue;
}
let handle = ensure_handle(input.clone());
let values = handle
.shape_history
.permute_vec(&self.device_memories[handle.id].values);
if !ib {
*chunk_size = Some(values.len() / num_parallel);
}
*ir_inputs = values;
}
let mut ir_inputs_per_parallel = Vec::new();
for parallel_i in 0..num_parallel {
let mut ir_inputs = vec![SIMDField::<C>::zero(); kernel.ir_for_calling().input_size()];
for (i, ((input, input_start), input_end)) in ios
.iter()
.zip(kernel.ir_input_offsets().iter())
.zip(kernel.ir_input_offsets().iter().skip(1))
.enumerate()
{
if input.is_none() {
continue;
}
self.ir_copy_from_device_memory(
&ir_inputs_all[i],
&mut ir_inputs[*input_start..*input_end],
is_broadcast[i],
parallel_i,
chunk_sizes[i],
);
}
ir_inputs_per_parallel.push(ir_inputs);
}
let ir_outputs_per_parallel: Vec<Result<Vec<SIMDField<C>>, Error>> = ir_inputs_per_parallel
.into_par_iter()
.map(|ir_inputs| {
kernel
.ir_for_calling()
.eval_safe_simd(ir_inputs, &[], &self.hint_caller)
})
.collect();
for ir_outputs in ir_outputs_per_parallel {
let ir_outputs = ir_outputs?;
for (((spec, output_start), output_end), out) in kernel
.io_specs()
.iter()
.zip(kernel.ir_output_offsets().iter())
.zip(kernel.ir_output_offsets().iter().skip(1))
.zip(outputs_tmp.iter_mut())
{
if !spec.is_output {
continue;
}
out.extend_from_slice(&ir_outputs[*output_start..*output_end]);
}
}
let mut outputs_tmp: Vec<Vec<SIMDField::<C>>> = vec![Vec::new(); kernel.io_specs().len()];
let input_handles = ios.to_vec();
let mut output_handles = vec![None; kernel.io_specs().len()];

Expand All @@ -447,12 +398,13 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
*output = None;
continue;
}
let handle = make_device_mem(
&mut self.device_memories,
ov,
shape_prepend(shape, num_parallel),
);
let id = handle.as_ref().unwrap().id;
// let handle = make_device_mem(
// &mut self.device_memories,
// ov,
// shape_prepend(shape, num_parallel),
// );
// let id = handle.as_ref().unwrap().id;
let (handle, id) = self.new_device_memory(shape_prepend(shape, num_parallel));
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block of commented-out code should be removed to improve code clarity and maintainability.

Suggested change
// let handle = make_device_mem(
// &mut self.device_memories,
// ov,
// shape_prepend(shape, num_parallel),
// );
// let id = handle.as_ref().unwrap().id;
let (handle, id) = self.new_device_memory(shape_prepend(shape, num_parallel));
let (handle, id) = self.new_device_memory(shape_prepend(shape, num_parallel));

self.device_memories[id].required_shape_products = merge_shape_products(
&handle
.as_ref()
Expand Down Expand Up @@ -732,6 +684,92 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
}
}
self.state = ContextState::WitnessDone;

for kernel_call in self.kernel_calls.iter() {
let kernel = self.kernel_primitives.get(kernel_call.kernel_id);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There's a minor formatting issue here with an extra space before self. Please remove it to maintain consistent code style.

Suggested change
let kernel = self.kernel_primitives.get(kernel_call.kernel_id);
let kernel = self.kernel_primitives.get(kernel_call.kernel_id);

let num_parallel = kernel_call.num_parallel;
let is_broadcast = &kernel_call.is_broadcast;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This empty line can be removed to improve code conciseness.

let mut ir_inputs_all = vec![Vec::new(); kernel.io_specs().len()];
let mut chunk_sizes: Vec<Option<usize>> = vec![None; kernel.io_specs().len()];
for (((input, &ib), ir_inputs), chunk_size) in kernel_call.input_handles
.iter()
.zip(is_broadcast.iter())
.zip(ir_inputs_all.iter_mut())
.zip(chunk_sizes.iter_mut())
{
if input.is_none() {
continue;
}
let handle = ensure_handle(input.clone());
let values = handle
.shape_history
.permute_vec(&self.device_memories[handle.id].values);
if !ib {
*chunk_size = Some(values.len() / num_parallel);
}
*ir_inputs = values;
}
let mut ir_inputs_per_parallel = Vec::new();
for parallel_i in 0..num_parallel {
let mut ir_inputs = vec![SIMDField::<C>::zero(); kernel.ir_for_calling().input_size()];
for (i, ((input, input_start), input_end)) in kernel_call.input_handles
.iter()
.zip(kernel.ir_input_offsets().iter())
.zip(kernel.ir_input_offsets().iter().skip(1))
.enumerate()
{
if input.is_none() {
continue;
}
self.ir_copy_from_device_memory(
&ir_inputs_all[i],
&mut ir_inputs[*input_start..*input_end],
is_broadcast[i],
parallel_i,
chunk_sizes[i],
);
}
ir_inputs_per_parallel.push(ir_inputs);
}
let ir_outputs_per_parallel: Vec<Result<Vec<SIMDField<C>>, Error>> = ir_inputs_per_parallel
.into_par_iter()
.map(|ir_inputs| {
kernel
.ir_for_calling()
.eval_safe_simd(ir_inputs, &[], &self.hint_caller)
})
.collect();

let mut outputs_tmp: Vec<Vec<SIMDField::<C>>> = vec![Vec::new(); kernel.io_specs().len()];
for ir_outputs in ir_outputs_per_parallel {
let ir_outputs = ir_outputs?;
for (((spec, output_start), output_end), out) in kernel
.io_specs()
.iter()
.zip(kernel.ir_output_offsets().iter())
.zip(kernel.ir_output_offsets().iter().skip(1))
.zip(outputs_tmp.iter_mut())
{
if !spec.is_output {
continue;
}
out.extend_from_slice(&ir_outputs[*output_start..*output_end]);
}
}

for ((output, spec), ov) in kernel_call.output_handles
.iter()
.zip(kernel.io_specs().iter())
.zip(outputs_tmp.into_iter())
{
if !spec.is_output {
continue;
}
let output_id = output.as_ref().unwrap().id;
self.device_memories[output_id].values = ov;
}
}
Comment on lines +688 to +759
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The code for kernel execution, which was moved into solve_witness, shares a lot of similar logic with the hint-solving part that follows (starting from line 693).

Specifically, the preparation of inputs for parallel execution is duplicated. To improve maintainability and reduce code duplication, consider refactoring the common logic for preparing inputs into a separate helper function. This would make the solve_witness function cleaner and easier to understand.


for (kernel_call, proof_template) in
self.kernel_calls.iter().zip(self.proof_templates.iter())
Expand Down
13 changes: 8 additions & 5 deletions expander_compiler/src/zkcuda/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,16 @@ fn context_shape_test_1_impl<P: ProvingSystem<M31Config>>() {

// Part 1
// Since we only use the shape [15, 1], the representation of the vector is "xxxxxxxxxxxxxxx.".
let mut a = ctx.copy_to_device(&vec![one; 15]);
let a_value_1 = vec![one; 15];
let (mut a, a_id_1) = ctx.new_device_memory(vec![15]);
call_kernel!(ctx, identity_1, 15, mut a).unwrap();
assert_eq!(ctx.copy_to_host::<Vec<F>>(a), vec![one; 15]);

// Part 2
// Since we use [15, 1] and [3, 5], the context will find a representation that is compatible with both.
// The representation of the vector is "xxxxx...xxxxx...xxxxx...........".
let mut a = ctx.copy_to_device(&vec![one; 15]);
let a_value_2 = vec![one; 15];
let (mut a, a_id_2) = ctx.new_device_memory(vec![15]);
let mut b = a.reshape(&[5, 3]);
call_kernel!(ctx, identity_1, 15, mut a).unwrap();
call_kernel!(ctx, identity_3, 5, mut b).unwrap();
Expand All @@ -84,6 +86,8 @@ fn context_shape_test_1_impl<P: ProvingSystem<M31Config>>() {
assert_eq!(ctx.copy_to_host::<Vec<F>>(b), vec![one; 15]);

let computation_graph = ctx.compile_computation_graph().unwrap();
ctx.copy_to_device(&a_value_1, a_id_1);
ctx.copy_to_device(&a_value_2, a_id_2);
ctx.solve_witness().unwrap();

// Debugging output and assertions
Expand Down Expand Up @@ -143,12 +147,11 @@ fn context_shape_test_1() {
fn context_shape_test_2() {
type C = M31Config;
type F = CircuitField<C>;
let one = F::one();
let identity_3 = compile_identity_3::<C>().unwrap();
let identity_5 = compile_identity_5::<C>().unwrap();

let mut ctx: Context<C> = Context::default();
let a = ctx.copy_to_device(&vec![one; 15]);
let (a, _) = ctx.new_device_memory(vec![15]);
let mut b = a.reshape(&[5, 3]);
let mut a = a.reshape(&[3, 5]);
call_kernel!(ctx, identity_5, 3, mut a).unwrap();
Expand All @@ -164,7 +167,7 @@ fn context_shape_test_2_success() {
let identity_5 = compile_identity_5::<C>().unwrap();

let mut ctx: Context<C> = Context::default();
let a = ctx.copy_to_device(&vec![one; 15]);
let (a, _) = ctx.new_device_memory(vec![15]);
let b = a.reshape(&[5, 3]);
let mut a = a.reshape(&[3, 5]);
call_kernel!(ctx, identity_5, 3, mut a).unwrap();
Expand Down
4 changes: 3 additions & 1 deletion expander_compiler/tests/cg_mpi_share.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,8 @@ fn get_computation_graph() -> ComputationGraph<M31Config> {
}

println!("prepare data ok");
let p = ctx.copy_to_device(&p);
let p_value = p;
let (p, p_id) = ctx.new_device_memory(vec![N_PARALLEL, 64 * 8]);
println!("copy to device ok");
let mut out = None;
call_kernel!(ctx, kernel, N_PARALLEL, p, mut out).unwrap();
Expand All @@ -338,6 +339,7 @@ fn get_computation_graph() -> ComputationGraph<M31Config> {
assert_eq!(out[0][0], expected_res[0][0]);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The refactoring to deferred execution has broken the logic of this test. The assert_eq! calls will now fail because out is checked before solve_witness is called, meaning it contains uninitialized data.

These assertions should be moved to after the ctx.solve_witness().unwrap() call on line 342.


let computation_graph = ctx.compile_computation_graph().unwrap();
ctx.copy_to_device(&p_value, p_id);

computation_graph
}
Expand Down
Loading
Loading