|
| 1 | +use expander_compiler::frontend::*; |
| 2 | +use expander_compiler::zkcuda::proving_system::expander::config::ZKCudaBN254KZGBatchPCS; |
| 3 | +use expander_compiler::zkcuda::proving_system::{ExpanderNoOverSubscribe, ProvingSystem}; |
| 4 | +use expander_compiler::zkcuda::shape::Reshape; |
| 5 | +use expander_compiler::zkcuda::{context::*, kernel::*}; |
| 6 | + |
| 7 | +#[kernel] |
| 8 | +fn add_2_macro<C: Config>(api: &mut API<C>, a: &[InputVariable; 2], b: &mut OutputVariable) { |
| 9 | + *b = api.add(a[0], a[1]); |
| 10 | +} |
| 11 | + |
| 12 | +#[kernel] |
| 13 | +fn add_16_macro<C: Config>(api: &mut API<C>, a: &[InputVariable; 16], b: &mut OutputVariable) { |
| 14 | + let mut sum = api.constant(0); |
| 15 | + for i in 0..16 { |
| 16 | + sum = api.add(sum, a[i]); |
| 17 | + } |
| 18 | + *b = sum; |
| 19 | +} |
| 20 | + |
| 21 | +fn test_bn254_load_graph_with_new_data_impl<C: Config, P: ProvingSystem<C>>() { |
| 22 | + let kernel_add_2: KernelPrimitive<C> = compile_add_2_macro().unwrap(); |
| 23 | + let kernel_add_16: KernelPrimitive<C> = compile_add_16_macro().unwrap(); |
| 24 | + |
| 25 | + println!("\n===== First execution: create and save graph (BN254) ====="); |
| 26 | + let mut ctx1: Context<C> = Context::default(); |
| 27 | + |
| 28 | + // First set of input data (BN254 field elements) |
| 29 | + let mut a1: Vec<Vec<CircuitField<C>>> = vec![]; |
| 30 | + for i in 0..16 { |
| 31 | + a1.push(vec![]); |
| 32 | + for j in 0..2 { |
| 33 | + a1[i].push(CircuitField::<C>::from((i * 2 + j + 1) as u32)); |
| 34 | + } |
| 35 | + } |
| 36 | + let a1 = ctx1.copy_to_device(&a1); |
| 37 | + let mut b1: DeviceMemoryHandle = None; |
| 38 | + call_kernel!(ctx1, kernel_add_2, 16, a1, mut b1).unwrap(); |
| 39 | + let b1 = b1.reshape(&[1, 16]); |
| 40 | + let mut c1: DeviceMemoryHandle = None; |
| 41 | + call_kernel!(ctx1, kernel_add_16, 1, b1, mut c1).unwrap(); |
| 42 | + let c1 = c1.reshape(&[]); |
| 43 | + let result1: CircuitField<C> = ctx1.copy_to_host(c1); |
| 44 | + println!("First result: {:?}", result1); |
| 45 | + assert_eq!(result1, CircuitField::<C>::from(32 * 33 / 2 as u32)); |
| 46 | + |
| 47 | + let computation_graph = ctx1.compile_computation_graph().unwrap(); |
| 48 | + ctx1.solve_witness().unwrap(); |
| 49 | + println!("Starting setup (may take some time)..."); |
| 50 | + let (prover_setup, verifier_setup) = P::setup(&computation_graph); |
| 51 | + println!("Starting prove..."); |
| 52 | + let proof1 = P::prove( |
| 53 | + &prover_setup, |
| 54 | + &computation_graph, |
| 55 | + ctx1.export_device_memories(), |
| 56 | + ); |
| 57 | + println!("Starting verify..."); |
| 58 | + assert!(P::verify(&verifier_setup, &computation_graph, &proof1)); |
| 59 | + println!("First verification passed!"); |
| 60 | + |
| 61 | + println!("\n===== Second execution: call_kernel first (new BN254 data), then load_graph ====="); |
| 62 | + let mut ctx2: Context<C> = Context::default(); |
| 63 | + |
| 64 | + // Second set of input data (different BN254 field elements) |
| 65 | + let mut a2: Vec<Vec<CircuitField<C>>> = vec![]; |
| 66 | + for i in 0..16 { |
| 67 | + a2.push(vec![]); |
| 68 | + for j in 0..2 { |
| 69 | + // Use different values: starting from 1000 |
| 70 | + a2[i].push(CircuitField::<C>::from((i * 2 + j + 1000) as u32)); |
| 71 | + } |
| 72 | + } |
| 73 | + let a2 = ctx2.copy_to_device(&a2); |
| 74 | + |
| 75 | + // Call kernels first (same order as the first time) |
| 76 | + let mut b2: DeviceMemoryHandle = None; |
| 77 | + println!("Calling first kernel (using new data)..."); |
| 78 | + call_kernel!(ctx2, kernel_add_2, 16, a2, mut b2).unwrap(); |
| 79 | + |
| 80 | + let b2 = b2.reshape(&[1, 16]); |
| 81 | + let mut c2: DeviceMemoryHandle = None; |
| 82 | + println!("Calling second kernel..."); |
| 83 | + call_kernel!(ctx2, kernel_add_16, 1, b2, mut c2).unwrap(); |
| 84 | + |
| 85 | + let c2 = c2.reshape(&[]); |
| 86 | + let result2: CircuitField<C> = ctx2.copy_to_host(c2); |
| 87 | + println!("Second computation result: {:?}", result2); |
| 88 | + |
| 89 | + // Verify results are indeed different |
| 90 | + assert_ne!(result1, result2, "The two results should be different"); |
| 91 | + |
| 92 | + // Expected result for the second run: |
| 93 | + // Input: [1000,1001], [1002,1003], ..., [1030,1031] (32 numbers total) |
| 94 | + // add_2: 2001, 2005, 2009, ..., 2061 (16 numbers) |
| 95 | + // add_16: sum(2001, 2005, ..., 2061) = 16 * (2001 + 2061) / 2 = 32496 |
| 96 | + let expected2 = CircuitField::<C>::from(32496u32); |
| 97 | + assert_eq!(result2, expected2, "Second result should be 32496"); |
| 98 | + |
| 99 | + // Now load the graph (reuse compiled kernels) |
| 100 | + println!("Loading computation_graph..."); |
| 101 | + ctx2.load_computation_graph(computation_graph.clone()) |
| 102 | + .unwrap(); |
| 103 | + println!("Graph loaded successfully!"); |
| 104 | + |
| 105 | + // solve_witness (will recalculate using new data) |
| 106 | + println!("solve_witness (recalculating witness)..."); |
| 107 | + ctx2.solve_witness().unwrap(); |
| 108 | + println!("solve_witness succeeded!"); |
| 109 | + |
| 110 | + // prove (using new data) |
| 111 | + println!("prove (generating proof with new data)..."); |
| 112 | + let proof2 = P::prove( |
| 113 | + &prover_setup, |
| 114 | + &computation_graph, |
| 115 | + ctx2.export_device_memories(), |
| 116 | + ); |
| 117 | + println!("prove succeeded!"); |
| 118 | + |
| 119 | + // verify |
| 120 | + println!("verify (verifying proof with new data)..."); |
| 121 | + assert!(P::verify(&verifier_setup, &computation_graph, &proof2)); |
| 122 | + println!("✓ Second verification passed!"); |
| 123 | + println!("✓ Successfully generated and verified different proofs using new BN254 data"); |
| 124 | + println!(" - First result: {:?}", result1); |
| 125 | + println!(" - Second result: {:?}", result2); |
| 126 | + |
| 127 | + P::post_process(); |
| 128 | +} |
| 129 | + |
| 130 | +#[test] |
| 131 | +fn test_bn254_load_graph_with_new_data() { |
| 132 | + test_bn254_load_graph_with_new_data_impl::<_, ExpanderNoOverSubscribe<ZKCudaBN254KZGBatchPCS>>( |
| 133 | + ); |
| 134 | +} |
0 commit comments