Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions expander_compiler/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ shared_memory.workspace = true
tiny-keccak.workspace = true
tokio.workspace = true
once_cell = "1.21.3"
stacker.workspace = true

[dev-dependencies]
rayon = "1.9"
Expand Down
82 changes: 41 additions & 41 deletions expander_compiler/src/zkcuda/context.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use core::num;

use arith::SimdField;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use serdes::ExpSerde;
Expand Down Expand Up @@ -44,7 +46,7 @@ pub struct KernelCall {
num_parallel: usize,
input_handles: Vec<DeviceMemoryHandle>,
output_handles: Vec<DeviceMemoryHandle>,
is_broadcast: Vec<bool>,
is_broadcast: Vec<usize>,
}

#[derive(PartialEq, Eq, Clone, Debug, ExpSerde)]
Expand All @@ -53,7 +55,7 @@ pub struct ProofTemplate {
pub commitment_indices: Vec<usize>,
pub commitment_bit_orders: Vec<BitOrder>,
pub parallel_count: usize,
pub is_broadcast: Vec<bool>,
pub is_broadcast: Vec<usize>,
}

impl ProofTemplate {
Expand All @@ -69,7 +71,7 @@ impl ProofTemplate {
pub fn parallel_count(&self) -> usize {
self.parallel_count
}
pub fn is_broadcast(&self) -> &[bool] {
pub fn is_broadcast(&self) -> &[usize] {
&self.is_broadcast
}
}
Expand Down Expand Up @@ -156,17 +158,19 @@ fn check_shape_compat(
kernel_shape: &Shape,
io_shape: &Shape,
parallel_count: usize,
) -> Option<bool> {
) -> Option<usize> {
if kernel_shape.len() == io_shape.len() {
if *kernel_shape == *io_shape {
Some(true)
Some(parallel_count)
} else {
None
}
} else if kernel_shape.len() + 1 == io_shape.len() {
if io_shape.iter().skip(1).eq(kernel_shape.iter()) {
if io_shape[0] == parallel_count {
Some(false)
Some(1)
} else if parallel_count % io_shape[0] == 0 {
Some(parallel_count / io_shape[0])
} else {
None
}
Expand Down Expand Up @@ -299,18 +303,15 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
&self,
values: &[SIMDField<C>],
s: &mut [SIMDField<C>],
is_broadcast: bool,
is_broadcast: usize,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The is_broadcast parameter is unused within ir_copy_from_device_memory and should be removed to improve code clarity. The function signature and all call sites (e.g., lines 418, 825, and 845) should be updated accordingly.

parallel_index: usize,
chunk_size: Option<usize>,
) {
if is_broadcast {
s.copy_from_slice(values);
} else {
let chunk_size = chunk_size.unwrap();
s.copy_from_slice(
&values[chunk_size * parallel_index..chunk_size * (parallel_index + 1)],
);
}
let chunk_size = chunk_size.unwrap();
let start_index = chunk_size * parallel_index % values.len();
s.copy_from_slice(
&values[start_index..(start_index + chunk_size)],
);
}

pub fn call_kernel(
Expand All @@ -332,7 +333,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
.enumerate()
{
if !spec.is_input {
is_broadcast.push(false);
is_broadcast.push(1);
continue;
}
/*println!(
Expand All @@ -350,7 +351,8 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
.as_ref()
.unwrap()
.shape_history
.get_initial_split_list(!ib);
.get_initial_split_list(ib/num_parallel+1);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The expression ib/num_parallel+1 is functionally correct but obscure. It distinguishes between full broadcast (ib == num_parallel) and other cases. Using a more explicit if expression would improve readability.

Suggested change
.get_initial_split_list(ib/num_parallel+1);
.get_initial_split_list(if ib == num_parallel { 2 } else { 1 });

// let isl = vec![1,64,4096];
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This commented-out line should be removed.

let t = io.as_ref().unwrap().id;
self.device_memories[t].required_shape_products = merge_shape_products(
&isl,
Expand All @@ -371,7 +373,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
}
}
for (io_spec, ib) in kernel.io_specs().iter().zip(is_broadcast.iter()) {
if io_spec.is_output && *ib {
if io_spec.is_output && *ib!=1 {
panic!("Output is broadcasted, but it shouldn't be");
}
}
Expand All @@ -381,11 +383,12 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
let mut outputs_tmp = vec![Vec::new(); kernel.io_specs().len()];
let mut ir_inputs_all = vec![Vec::new(); kernel.io_specs().len()];
let mut chunk_sizes: Vec<Option<usize>> = vec![None; kernel.io_specs().len()];
for (((input, &ib), ir_inputs), chunk_size) in ios
for ((((input, &ib), ir_inputs), chunk_size), kernel_shape) in ios
.iter()
.zip(is_broadcast.iter())
.zip(ir_inputs_all.iter_mut())
.zip(chunk_sizes.iter_mut())
.zip(kernel.io_shapes().iter())
{
if input.is_none() {
continue;
Expand All @@ -394,9 +397,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
let values = handle
.shape_history
.permute_vec(&self.device_memories[handle.id].values);
if !ib {
*chunk_size = Some(values.len() / num_parallel);
}
*chunk_size = Some(kernel_shape.iter().product());
*ir_inputs = values;
}
let mut ir_inputs_per_parallel = Vec::new();
Expand Down Expand Up @@ -469,7 +470,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
.as_ref()
.unwrap()
.shape_history
.get_initial_split_list(true),
.get_initial_split_list(1),
&self.device_memories[id].required_shape_products,
);
*output = handle.clone();
Expand Down Expand Up @@ -513,7 +514,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
.zip(kernel_call.input_handles.iter())
.zip(kernel_call.is_broadcast.iter())
{
if !spec.is_input || ib {
if !spec.is_input || ib > 1 {
continue;
}
let pad_shape = get_pad_shape(input_handle).unwrap();
Expand All @@ -526,7 +527,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
.zip(kernel_call.output_handles.iter())
.zip(kernel_call.is_broadcast.iter())
{
if !spec.is_output || ib {
if !spec.is_output || ib > 1 {
continue;
}
let pad_shape = get_pad_shape(output_handle).unwrap();
Expand All @@ -549,7 +550,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
if x != 1 && x != kernel_call.num_parallel {
let sh_tmp = handle.shape_history.reshape(&[x, total / x]);
dm.required_shape_products = merge_shape_products(
&sh_tmp.get_initial_split_list(true),
&sh_tmp.get_initial_split_list(1),
&dm.required_shape_products,
);
}
Expand All @@ -576,7 +577,6 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
self.state = ContextState::ComputationGraphDone;

let dm_shapes = self.propagate_and_get_shapes();

let (mut cg_kernels, cg_proof_templates, cg_commitments_lens) = if let Some(cg) = cg {
for (i, kernel) in cg.kernels.iter().enumerate() {
assert_eq!(self.kernels.add(kernel), i);
Expand Down Expand Up @@ -622,11 +622,11 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
let mut psi = Vec::new();
for (s, &ib) in pad_shapes_input.iter().zip(kernel_call.is_broadcast.iter()) {
psi.push(s.as_ref().map(|t| {
if ib {
if ib == kernel_call.num_parallel {
t.0.clone()
} else {
keep_shape_since(&t.0, kernel_call.num_parallel)
}
} else{
keep_shape_since(&t.0, kernel_call.num_parallel/ib)
}
}));
}
let mut pso = Vec::new();
Expand All @@ -635,7 +635,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
.zip(kernel_call.is_broadcast.iter())
{
pso.push(s.as_ref().map(|t| {
if ib {
if ib == kernel_call.num_parallel {
t.0.clone()
} else {
keep_shape_since(&t.0, kernel_call.num_parallel)
Expand All @@ -661,7 +661,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
commitment_indices.push(handle.as_ref().unwrap().id);
commitment_bit_orders.push(shape.1.clone());
is_broadcast.push(ib);
if !ib {
if ib == 1 {
any_shape = Some(shape.0.clone());
}
}
Expand All @@ -678,7 +678,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
commitment_indices.push(handle.as_ref().unwrap().id);
commitment_bit_orders.push(shape.1.clone());
is_broadcast.push(ib);
if !ib {
if ib == 1 {
any_shape = Some(shape.0.clone());
}
}
Expand All @@ -695,7 +695,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
dm_max += 1;
commitment_bit_orders.push((0..n.trailing_zeros() as usize).collect());
commitments_lens.push(n);
is_broadcast.push(false);
is_broadcast.push(1);
}

let kernel_id = self.kernels.add(&kernel);
Expand Down Expand Up @@ -778,9 +778,8 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
let values = handle
.shape_history
.permute_vec(&self.device_memories[handle.id].values);
if !ib {
*chunk_size = Some(values.len() / kernel_call.num_parallel);
}
let kernel_shape = handle.shape_history.shape();
*chunk_size = Some(kernel_shape.iter().product());
*ir_inputs = values;
}
for (((output, &ib), ir_inputs), chunk_size) in kernel_call
Expand All @@ -800,7 +799,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
let values = handle
.shape_history
.permute_vec(&self.device_memories[handle.id].values);
assert!(!ib);
assert!(ib == 1);
*chunk_size = Some(values.len() / kernel_call.num_parallel);
*ir_inputs = values;
}
Expand All @@ -823,7 +822,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
self.ir_copy_from_device_memory(
ir_inputs,
&mut inputs[*input_start..*input_end],
chunk_size.is_none(),
chunk_size.unwrap_or(2),
parallel_i,
*chunk_size,
);
Expand All @@ -843,7 +842,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
self.ir_copy_from_device_memory(
ir_outputs,
&mut inputs[*output_start..*output_end],
chunk_size.is_none(),
chunk_size.unwrap_or(1),
parallel_i,
*chunk_size,
);
Expand Down Expand Up @@ -891,6 +890,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
.map(|dm| {
let shape = prefix_products_to_shape(&dm.required_shape_products);
let im = shape_padded_mapping(&shape);
let tmp = im.map_inputs(&dm.values);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The variable tmp is assigned but never used. It should be removed.

im.map_inputs(&dm.values)
})
.collect()
Expand Down
3 changes: 1 addition & 2 deletions expander_compiler/src/zkcuda/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,8 @@ fn reorder_ir_inputs<C: Config>(
lc_in[i].len = n;
assert!(var_max % n == 0);
let im = shape_padded_mapping(&pad_shapes[i]);
// println!("{:?}", im.mapping());
for (j, &k) in im.mapping().iter().enumerate() {
var_new_id[prev + k + 1] = var_max + j + 1;
var_new_id[prev + j + 1] = var_max + k + 1;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This line appears to have incorrect logic for remapping variable IDs. var_new_id should map from the old variable ID to the new one. The im.mapping() provides a mapping from a padded index j to an original index k. The old variable ID is based on k, and the new ID is based on j.

The current implementation var_new_id[prev + j + 1] = var_max + k + 1; attempts to use a padded index j to index var_new_id, which is sized for original, unpadded inputs. This will likely lead to an out-of-bounds panic when padding is active (padded_len > original_len).

The logic should be var_new_id[old_id] = new_id, which translates to var_new_id[prev + k + 1] = var_max + j + 1;.

Suggested change
var_new_id[prev + j + 1] = var_max + k + 1;
var_new_id[prev + k + 1] = var_max + j + 1;

}
var_max += n;
}
Expand Down
2 changes: 1 addition & 1 deletion expander_compiler/src/zkcuda/mpi_mem_share.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl MPISharedMemory for ProofTemplate {
.map(|_| BitOrder::new_from_memory(ptr))
.collect();
let parallel_count = usize::new_from_memory(ptr);
let is_broadcast = Vec::<bool>::new_from_memory(ptr);
let is_broadcast = Vec::<usize>::new_from_memory(ptr);

ProofTemplate {
kernel_id,
Expand Down
Loading
Loading