Skip to content

Commit 1be709d

Browse files
committed
Merge branch 'hc/reuse_graph' into hc/schedule
2 parents 736a6e2 + ca286e5 commit 1be709d

2 files changed

Lines changed: 138 additions & 4 deletions

File tree

expander_compiler/src/zkcuda/context.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
577577

578578
let dm_shapes = self.propagate_and_get_shapes();
579579

580-
let (mut cg_kernels, cg_proof_templates, cg_commitments_lens) = if let Some(cg) = cg {
580+
let (cg_kernels, cg_proof_templates, cg_commitments_lens) = if let Some(cg) = cg {
581581
for (i, kernel) in cg.kernels.iter().enumerate() {
582582
assert_eq!(self.kernels.add(kernel), i);
583583
}
@@ -617,7 +617,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
617617
.collect::<Vec<_>>();
618618
let kernel_primitive = self.kernel_primitives.get(kernel_call.kernel_id);
619619
let kernel = if cg_kernels.is_some() {
620-
// 从已加载的 kernels 中通过 kernel_id 获取
620+
// Get kernel from loaded kernels by kernel_id
621621
self.kernels.get(kernel_call.kernel_id).clone()
622622
} else {
623623
let mut psi = Vec::new();
@@ -710,8 +710,8 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
710710
}
711711

712712
if let Some(_cg_kernels) = cg_kernels {
713-
// 不再检查 cg_kernels 是否为空,因为我们不再消耗它
714-
// kernels 已经在之前通过 self.kernels.add() 添加了
713+
// No longer checking if cg_kernels is empty since we no longer consume it
714+
// Kernels were already added earlier via self.kernels.add()
715715
assert_eq!(cg_proof_templates.unwrap(), self.proof_templates);
716716
assert_eq!(cg_commitments_lens.unwrap(), commitments_lens);
717717
Ok(None)
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
use expander_compiler::frontend::*;
2+
use expander_compiler::zkcuda::proving_system::expander::config::ZKCudaBN254KZGBatchPCS;
3+
use expander_compiler::zkcuda::proving_system::{ExpanderNoOverSubscribe, ProvingSystem};
4+
use expander_compiler::zkcuda::shape::Reshape;
5+
use expander_compiler::zkcuda::{context::*, kernel::*};
6+
7+
#[kernel]
8+
fn add_2_macro<C: Config>(api: &mut API<C>, a: &[InputVariable; 2], b: &mut OutputVariable) {
9+
*b = api.add(a[0], a[1]);
10+
}
11+
12+
#[kernel]
13+
fn add_16_macro<C: Config>(api: &mut API<C>, a: &[InputVariable; 16], b: &mut OutputVariable) {
14+
let mut sum = api.constant(0);
15+
for i in 0..16 {
16+
sum = api.add(sum, a[i]);
17+
}
18+
*b = sum;
19+
}
20+
21+
fn test_bn254_load_graph_with_new_data_impl<C: Config, P: ProvingSystem<C>>() {
22+
let kernel_add_2: KernelPrimitive<C> = compile_add_2_macro().unwrap();
23+
let kernel_add_16: KernelPrimitive<C> = compile_add_16_macro().unwrap();
24+
25+
println!("\n===== First execution: create and save graph (BN254) =====");
26+
let mut ctx1: Context<C> = Context::default();
27+
28+
// First set of input data (BN254 field elements)
29+
let mut a1: Vec<Vec<CircuitField<C>>> = vec![];
30+
for i in 0..16 {
31+
a1.push(vec![]);
32+
for j in 0..2 {
33+
a1[i].push(CircuitField::<C>::from((i * 2 + j + 1) as u32));
34+
}
35+
}
36+
let a1 = ctx1.copy_to_device(&a1);
37+
let mut b1: DeviceMemoryHandle = None;
38+
call_kernel!(ctx1, kernel_add_2, 16, a1, mut b1).unwrap();
39+
let b1 = b1.reshape(&[1, 16]);
40+
let mut c1: DeviceMemoryHandle = None;
41+
call_kernel!(ctx1, kernel_add_16, 1, b1, mut c1).unwrap();
42+
let c1 = c1.reshape(&[]);
43+
let result1: CircuitField<C> = ctx1.copy_to_host(c1);
44+
println!("First result: {:?}", result1);
45+
assert_eq!(result1, CircuitField::<C>::from(32 * 33 / 2 as u32));
46+
47+
let computation_graph = ctx1.compile_computation_graph().unwrap();
48+
ctx1.solve_witness().unwrap();
49+
println!("Starting setup (may take some time)...");
50+
let (prover_setup, verifier_setup) = P::setup(&computation_graph);
51+
println!("Starting prove...");
52+
let proof1 = P::prove(
53+
&prover_setup,
54+
&computation_graph,
55+
ctx1.export_device_memories(),
56+
);
57+
println!("Starting verify...");
58+
assert!(P::verify(&verifier_setup, &computation_graph, &proof1));
59+
println!("First verification passed!");
60+
61+
println!("\n===== Second execution: call_kernel first (new BN254 data), then load_graph =====");
62+
let mut ctx2: Context<C> = Context::default();
63+
64+
// Second set of input data (different BN254 field elements)
65+
let mut a2: Vec<Vec<CircuitField<C>>> = vec![];
66+
for i in 0..16 {
67+
a2.push(vec![]);
68+
for j in 0..2 {
69+
// Use different values: starting from 1000
70+
a2[i].push(CircuitField::<C>::from((i * 2 + j + 1000) as u32));
71+
}
72+
}
73+
let a2 = ctx2.copy_to_device(&a2);
74+
75+
// Call kernels first (same order as the first time)
76+
let mut b2: DeviceMemoryHandle = None;
77+
println!("Calling first kernel (using new data)...");
78+
call_kernel!(ctx2, kernel_add_2, 16, a2, mut b2).unwrap();
79+
80+
let b2 = b2.reshape(&[1, 16]);
81+
let mut c2: DeviceMemoryHandle = None;
82+
println!("Calling second kernel...");
83+
call_kernel!(ctx2, kernel_add_16, 1, b2, mut c2).unwrap();
84+
85+
let c2 = c2.reshape(&[]);
86+
let result2: CircuitField<C> = ctx2.copy_to_host(c2);
87+
println!("Second computation result: {:?}", result2);
88+
89+
// Verify results are indeed different
90+
assert_ne!(result1, result2, "The two results should be different");
91+
92+
// Expected result for the second run:
93+
// Input: [1000,1001], [1002,1003], ..., [1030,1031] (32 numbers total)
94+
// add_2: 2001, 2005, 2009, ..., 2061 (16 numbers)
95+
// add_16: sum(2001, 2005, ..., 2061) = 16 * (2001 + 2061) / 2 = 32496
96+
let expected2 = CircuitField::<C>::from(32496u32);
97+
assert_eq!(result2, expected2, "Second result should be 32496");
98+
99+
// Now load the graph (reuse compiled kernels)
100+
println!("Loading computation_graph...");
101+
ctx2.load_computation_graph(computation_graph.clone())
102+
.unwrap();
103+
println!("Graph loaded successfully!");
104+
105+
// solve_witness (will recalculate using new data)
106+
println!("solve_witness (recalculating witness)...");
107+
ctx2.solve_witness().unwrap();
108+
println!("solve_witness succeeded!");
109+
110+
// prove (using new data)
111+
println!("prove (generating proof with new data)...");
112+
let proof2 = P::prove(
113+
&prover_setup,
114+
&computation_graph,
115+
ctx2.export_device_memories(),
116+
);
117+
println!("prove succeeded!");
118+
119+
// verify
120+
println!("verify (verifying proof with new data)...");
121+
assert!(P::verify(&verifier_setup, &computation_graph, &proof2));
122+
println!("✓ Second verification passed!");
123+
println!("✓ Successfully generated and verified different proofs using new BN254 data");
124+
println!(" - First result: {:?}", result1);
125+
println!(" - Second result: {:?}", result2);
126+
127+
P::post_process();
128+
}
129+
130+
#[test]
131+
fn test_bn254_load_graph_with_new_data() {
132+
test_bn254_load_graph_with_new_data_impl::<_, ExpanderNoOverSubscribe<ZKCudaBN254KZGBatchPCS>>(
133+
);
134+
}

0 commit comments

Comments
 (0)