From 7c1b854f60a48c72cf665c21d8c8e13f11fc6a64 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Tue, 19 May 2026 02:38:58 +0000 Subject: [PATCH 01/33] feat(bb/msm): workgroup-scan fused round kernel + PackedField primitive layer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Foundation for msm_webgpu_v2 rewrite. Two new files; no existing code modified beyond shader_manager.ts and the auto-regenerated wgsl/_generated/shaders.ts. - wgsl/struct/packed_field.template.wgsl — new partial defining struct PackedField { lo: vec4, hi: vec4 } and the wrappers mont_p, fr_add_p, fr_sub_p, fr_neg_p, fr_inv_p, plus field_load_ro/_rw, field_store, is_zero_packed, eq_packed, get_p_packed, get_r_packed, get_zero_packed. Each primitive body is a 3-line unpack-call-pack wrapper around the existing BigInt-limb implementations. The 20x13-bit BigInt representation only ever exists transiently inside these wrappers. - wgsl/cuzk/batch_affine_fused_wg_scan.template.wgsl — new fused round kernel. TPB threads cooperating on BATCH_SIZE = TPB*BS pairs per workgroup with one fr_inv_by_a per workgroup. Direct port of the bench_batch_affine design (validated at 22 ns/pair on M2) with bucket-indirect loads/stores via pair_target_meta. Every field-element variable is PackedField. The scheduler-enforced distinct-buckets invariant means the phase-D scatter has zero intra-workgroup RAW hazards. - cuzk/shader_manager.ts — adds gen_batch_affine_fused_wg_scan_shader. Validates TPB is a power of two for Hillis-Steele. Not in this PR: msm_webgpu_v2 directory fork, packed ports of BPR / horner / finalize / convert, host-side dispatch rewrite, e2e bench harness. Those land in follow-up PRs once this layer is validated. Validation done: yarn generate:wgsl regenerates cleanly, ShaderManager end-to-end render produces 147 KB of WGSL with all mustache tags substituted and every expected symbol present. Validation NOT done: WebGPU compilation (no Playwright/SwiftShader in this container), correctness vs CPU oracle, real-hardware perf — all needs the operator. --- .../ts/src/msm_webgpu/cuzk/shader_manager.ts | 53 +++ .../src/msm_webgpu/wgsl/_generated/shaders.ts | 374 +++++++++++++++++- .../batch_affine_fused_wg_scan.template.wgsl | 237 +++++++++++ .../wgsl/struct/packed_field.template.wgsl | 131 ++++++ 4 files changed, 794 insertions(+), 1 deletion(-) create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_fused_wg_scan.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/struct/packed_field.template.wgsl diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts index 6c97ef584582..05948ff3f852 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts @@ -10,6 +10,7 @@ import { batch_affine_finalize_apply as batch_affine_finalize_apply_shader, batch_affine_finalize_collect as batch_affine_finalize_collect_shader, batch_affine_fused_revcarry as batch_affine_fused_revcarry_shader, + batch_affine_fused_wg_scan as batch_affine_fused_wg_scan_shader, batch_affine_init as batch_affine_init_shader, batch_affine_schedule as batch_affine_schedule_shader, batch_inverse as batch_inverse_shader, @@ -42,6 +43,7 @@ import { mont_pro_product_f32_22_sos3uv3 as montgomery_product_f32_22_sos3uv3_funcs, mont_pro_product_karat_yuval as montgomery_product_karat_yuval_funcs, mulhilo_22 as mulhilo_22_funcs, + packed_field as packed_field_funcs, smvp_bn254 as smvp_bn254_shader, smvp_tree_entry_bucket_id as smvp_tree_entry_bucket_id_shader, smvp_tree_phase1 as smvp_tree_phase1_shader, @@ -877,6 +879,57 @@ ${packLines.join('\n')} ); } + /** + * Workgroup-scan fused batch-affine round kernel (v2). Mirrors the + * `bench_batch_affine` design (TPB threads cooperating over a + * BATCH_SIZE=TPB*BS pair slice with one fr_inv_by_a per workgroup) + * but with bucket-indirect loads via pair_target_meta. Every + * field-element variable is `PackedField`; no per-load unpack at + * the kernel level. + */ + public gen_batch_affine_fused_wg_scan_shader(tpb: number, bs: number): string { + if (tpb <= 0 || bs <= 0 || !Number.isInteger(tpb) || !Number.isInteger(bs)) { + throw new Error(`gen_batch_affine_fused_wg_scan_shader: tpb (${tpb}) and bs (${bs}) must be positive integers`); + } + if ((tpb & (tpb - 1)) !== 0) { + throw new Error(`gen_batch_affine_fused_wg_scan_shader: tpb (${tpb}) must be a power of two (Hillis-Steele scan)`); + } + const batch_size = tpb * bs; + const dec = this.decoupledPackUnpackWgsl(); + return mustache.render( + batch_affine_fused_wg_scan_shader, + { + tpb, + bs, + batch_size, + word_size: this.word_size, + num_words: this.num_words, + n0: this.n0, + p_limbs: this.p_limbs, + r_limbs: this.r_limbs, + r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, + mask: this.mask, + two_pow_word_size: this.two_pow_word_size, + p_inv_mod_2w: this.p_inv_mod_2w, + p_inv_by_a_lo: this.p_inv_by_a_lo, + dec_unpack: dec.unpack, + dec_pack: dec.pack, + recompile: this.recompile, + }, + { + structs, + bigint_funcs, + montgomery_product_funcs: this.mont_product_src, + field_funcs, + fr_pow_funcs, + bigint_by_funcs, + by_inverse_a_funcs, + packed_field_funcs, + }, + ); + } + public gen_batch_affine_finalize_shader(workgroup_size: number, num_csr_cols: number): string { return mustache.render( batch_affine_finalize_shader, diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index 741d3480f737..d09c0dea0f1a 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -1,6 +1,6 @@ // AUTO-GENERATED by scripts/inline-wgsl.mjs. DO NOT EDIT. // Run `yarn generate:wgsl` (or `node scripts/inline-wgsl.mjs`) to regenerate. -// 48 shader sources inlined. +// 50 shader sources inlined. /* eslint-disable */ @@ -2660,6 +2660,245 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { } `; +export const batch_affine_fused_wg_scan = `{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +{{> packed_field_funcs }} + +// Workgroup-scan fused batch-affine round kernel for v2 MSM. +// +// Each workgroup of TPB threads cooperates on BATCH_SIZE = TPB * BS pairs +// from one subtask's pair pool, performs a workgroup-level Hillis-Steele +// prefix product over per-thread chunks, runs ONE fr_inv_by_a per +// workgroup, then back-walks per-thread emitting lean affine adds. This +// is the design validated in \`bench_batch_affine.template.wgsl\` (22 +// ns/pair at TPB=64, BS=16 on M2) with bucket-indirect loads/stores via +// \`pair_target_meta\`. +// +// LAYOUT +// - All field-element variables (workgroup, function, struct fields) +// are \`PackedField\` (two vec4). The 20×13-bit BigInt limb form +// only exists as a transient local inside mont_p / fr_*_p / fr_inv_p. +// - Per-subtask pair pool of length n (= count_buf[subtask_idx]) is +// dispatched as ceil(n / BATCH_SIZE) workgroups in X, num_subtasks +// in Z. The last workgroup of each subtask may have a partial batch +// (n - batch_base < BATCH_SIZE); threads with chunk_start >= +// batch_len contribute identity (R in Mont form) to the scan and +// skip phase D. +// +// PHASES +// A) Per-thread serial chunk: walk BS pairs, compute dx = Q.x - P.x +// and the inclusive prefix product. Captures block_total in a +// register, writes the per-element prefix into prefix_buf. +// B) Workgroup Hillis-Steele forward + backward scan over the TPB +// block_totals (log2 TPB rounds of mont mul). +// C) Thread 0 inverts the global product via fr_inv_by_a (ONE per +// workgroup). Broadcasts to wg_inv_total. +// D) Each thread back-walks its chunk, recovers inv_dx for each pair +// from (wg_inv_total * block_excl_prefix * block_excl_suffix * +// prev_in_chunk_prefix), emits lean affine add, scatters to +// running_x/y[bucket]. +// +// SAFETY +// The scheduler emits at most one pair per (subtask, bucket) per round +// (see batch_affine_schedule). So within a workgroup's BATCH_SIZE +// slots, every \`bucket\` is distinct → no intra-workgroup RAW hazards +// on the running_x/y scatters. Across workgroups in the same subtask: +// disjoint slot ranges → still distinct buckets. Across subtasks +// (Z dim): different bucket ranges entirely. + +const TPB: u32 = {{ tpb }}u; +const BS: u32 = {{ bs }}u; +const BATCH_SIZE: u32 = {{ batch_size }}u; + +@group(0) @binding(0) +var val_idx: array; +@group(0) @binding(1) +var new_point_x: array>; +@group(0) @binding(2) +var new_point_y: array>; +@group(0) @binding(3) +var running_x: array>; +@group(0) @binding(4) +var running_y: array>; +@group(0) @binding(5) +var pair_target_meta: array; +@group(0) @binding(6) +var prefix_buf: array>; +@group(0) @binding(7) +var count_buf: array>; + +// params[0] = num_columns (per-subtask pool stride) +// params[1] = input_size (per-subtask val_idx stride) +@group(0) @binding(8) +var params: vec4; + +var wg_fwd: array; +var wg_bwd: array; +var wg_inv_total: PackedField; + +@compute +@workgroup_size({{ tpb }}) +fn main( + @builtin(local_invocation_id) lid: vec3, + @builtin(workgroup_id) wid: vec3, +) { + let tid = lid.x; + let wg_idx = wid.x; + let subtask_idx = wid.z; + let num_columns = params[0]; + let input_size = params[1]; + + let n = atomicLoad(&count_buf[subtask_idx]); + let batch_base = wg_idx * BATCH_SIZE; + if (batch_base >= n) { + return; + } + + let pool_base = subtask_idx * num_columns; + let vi_offset = subtask_idx * input_size; + + let remaining = n - batch_base; + let batch_len = min(BATCH_SIZE, remaining); + + let chunk_start = tid * BS; + var chunk_len: u32 = 0u; + if (chunk_start < batch_len) { + let chunk_end_unclamped = chunk_start + BS; + let chunk_end = min(chunk_end_unclamped, batch_len); + chunk_len = chunk_end - chunk_start; + } + + // Phase A — per-thread serial prefix product. Threads with + // chunk_len == 0 contribute identity (R = Mont 1) so the workgroup + // scan reads a sane value for every slot. + var block_total: PackedField = get_r_packed(); + if (chunk_len > 0u) { + let k0 = chunk_start; + let slot0 = pool_base + batch_base + k0; + let bucket0 = pair_target_meta[2u * slot0]; + let cursor0 = pair_target_meta[2u * slot0 + 1u]; + let pt_idx0 = val_idx[vi_offset + cursor0]; + let p_x0 = field_load_rw(bucket0, &running_x); + let q_x0 = field_load_ro(pt_idx0, &new_point_x); + let dx0 = fr_sub_p(q_x0, p_x0); + field_store(pool_base + batch_base + k0, &prefix_buf, dx0); + block_total = dx0; + + for (var i: u32 = 1u; i < BS; i = i + 1u) { + if (i >= chunk_len) { break; } + let k = chunk_start + i; + let slot = pool_base + batch_base + k; + let bucket = pair_target_meta[2u * slot]; + let cursor = pair_target_meta[2u * slot + 1u]; + let pt_idx = val_idx[vi_offset + cursor]; + let p_x = field_load_rw(bucket, &running_x); + let q_x = field_load_ro(pt_idx, &new_point_x); + let dx = fr_sub_p(q_x, p_x); + block_total = mont_p(block_total, dx); + field_store(pool_base + batch_base + k, &prefix_buf, block_total); + } + } + + wg_fwd[tid] = block_total; + wg_bwd[tid] = block_total; + workgroupBarrier(); + + // Phase B — Hillis-Steele forward + backward inclusive scan. + for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) { + var fwd_x: PackedField = wg_fwd[tid]; + if (tid >= stride) { + let lhs = wg_fwd[tid - stride]; + fwd_x = mont_p(lhs, fwd_x); + } + var bwd_x: PackedField = wg_bwd[tid]; + if (tid + stride < TPB) { + let rhs = wg_bwd[tid + stride]; + bwd_x = mont_p(bwd_x, rhs); + } + workgroupBarrier(); + wg_fwd[tid] = fwd_x; + wg_bwd[tid] = bwd_x; + workgroupBarrier(); + } + + // Phase C — single fr_inv per workgroup. wg_fwd[TPB-1] holds the + // product of every active (and identity-padding) block_total in the + // workgroup. + if (tid == 0u) { + let global_total = wg_fwd[TPB - 1u]; + wg_inv_total = fr_inv_p(global_total); + } + workgroupBarrier(); + + // Phase D — back-walk this thread's chunk, emit lean affine adds. + if (chunk_len == 0u) { + return; + } + + var block_excl_prefix: PackedField = get_r_packed(); + if (tid > 0u) { + block_excl_prefix = wg_fwd[tid - 1u]; + } + var block_excl_suffix: PackedField = get_r_packed(); + if (tid + 1u < TPB) { + block_excl_suffix = wg_bwd[tid + 1u]; + } + var inv_acc: PackedField = mont_p(wg_inv_total, block_excl_prefix); + inv_acc = mont_p(inv_acc, block_excl_suffix); + + for (var off: u32 = 0u; off < BS; off = off + 1u) { + if (off >= chunk_len) { break; } + let k = chunk_start + (chunk_len - 1u - off); + let slot = pool_base + batch_base + k; + let bucket = pair_target_meta[2u * slot]; + let cursor = pair_target_meta[2u * slot + 1u]; + let pt_idx = val_idx[vi_offset + cursor]; + + let p_x = field_load_rw(bucket, &running_x); + let p_y = field_load_rw(bucket, &running_y); + let q_x = field_load_ro(pt_idx, &new_point_x); + let q_y = field_load_ro(pt_idx, &new_point_y); + + var inv_dx: PackedField; + if (k > chunk_start) { + let prev = field_load_rw(pool_base + batch_base + (k - 1u), &prefix_buf); + inv_dx = mont_p(inv_acc, prev); + } else { + inv_dx = inv_acc; + } + + let dy = fr_sub_p(q_y, p_y); + let lambda = mont_p(dy, inv_dx); + let lambda_sq = mont_p(lambda, lambda); + var r_x = fr_sub_p(lambda_sq, p_x); + r_x = fr_sub_p(r_x, q_x); + let dx_back = fr_sub_p(p_x, r_x); + let ldx = mont_p(lambda, dx_back); + let r_y = fr_sub_p(ldx, p_y); + + field_store(bucket, &running_x, r_x); + field_store(bucket, &running_y, r_y); + + if (k > chunk_start) { + let dx_fwd = fr_sub_p(q_x, p_x); + inv_acc = mont_p(inv_acc, dx_fwd); + } + } + + {{{ recompile }}} +} +`; + export const batch_affine_init = `{{> structs }} // Init kernel for the batch-affine SMVP pipeline. @@ -9644,6 +9883,139 @@ fn mulhilo2(a: vec2, b: vec2) -> vec4 { } `; +export const packed_field = `// Packed 256-bit field-element type and primitive wrappers for v2 MSM. +// +// A \`PackedField\` holds one canonical [0, q) BN254 base-field value as two +// vec4 (8 × u32 = 32 bytes little-endian). Storage buffers are +// \`array>\` with logical stride 2 vec4s per element. +// +// Design constraint (from the v2 plan): every shader-level field-element +// variable, struct field, workgroup-shared var, and binding is +// \`PackedField\`. The 20×13-bit \`BigInt\` representation only appears as a +// transient local inside the wrappers below. No kernel ever calls +// \`unpack256_to_limbs\` or \`pack_limbs_to_256\` directly. +// +// Cost per primitive call: ~2 unpacks + 1 pack on top of the underlying +// BigInt operation. On Apple M2 each pack/unpack is ~10 cycles vs ~100 +// cycles for \`montgomery_product\`, so chains of mont-muls pay <15% +// overhead vs the BigInt calling convention used by the legacy +// msm_webgpu/ shaders. +// +// PRECONDITION: this partial must be included after \`bigint_funcs\`, +// \`montgomery_product_funcs\`, \`field_funcs\`, \`by_inverse_a_funcs\`, and +// after the {{{ dec_unpack }}} / {{{ dec_pack }}} substitution blocks +// have rendered \`unpack256_to_limbs\` / \`pack_limbs_to_256\` into the +// shader. + +struct PackedField { + lo: vec4, + hi: vec4, +} + +fn pf_to_words(p: PackedField) -> array { + var w: array; + w[0] = p.lo.x; w[1] = p.lo.y; w[2] = p.lo.z; w[3] = p.lo.w; + w[4] = p.hi.x; w[5] = p.hi.y; w[6] = p.hi.z; w[7] = p.hi.w; + return w; +} + +fn pf_from_words(w0: u32, w1: u32, w2: u32, w3: u32, + w4: u32, w5: u32, w6: u32, w7: u32) -> PackedField { + var p: PackedField; + p.lo = vec4(w0, w1, w2, w3); + p.hi = vec4(w4, w5, w6, w7); + return p; +} + +fn unpack_field(p: PackedField) -> BigInt { + let w = pf_to_words(p); + return unpack256_to_limbs(w); +} + +fn pack_field(b: ptr) -> PackedField { + let w = pack_limbs_to_256(b); + return pf_from_words(w[0], w[1], w[2], w[3], w[4], w[5], w[6], w[7]); +} + +fn field_load_ro(idx: u32, src: ptr>, read>) -> PackedField { + var p: PackedField; + p.lo = (*src)[2u * idx]; + p.hi = (*src)[2u * idx + 1u]; + return p; +} + +fn field_load_rw(idx: u32, src: ptr>, read_write>) -> PackedField { + var p: PackedField; + p.lo = (*src)[2u * idx]; + p.hi = (*src)[2u * idx + 1u]; + return p; +} + +fn field_store(idx: u32, dst: ptr>, read_write>, val: PackedField) { + (*dst)[2u * idx] = val.lo; + (*dst)[2u * idx + 1u] = val.hi; +} + +fn is_zero_packed(a: PackedField) -> bool { + return all(a.lo == vec4(0u, 0u, 0u, 0u)) + && all(a.hi == vec4(0u, 0u, 0u, 0u)); +} + +fn eq_packed(a: PackedField, b: PackedField) -> bool { + return all(a.lo == b.lo) && all(a.hi == b.hi); +} + +fn get_zero_packed() -> PackedField { + return PackedField(vec4(0u), vec4(0u)); +} + +fn get_p_packed() -> PackedField { + var p: BigInt = get_p(); + return pack_field(&p); +} + +fn get_r_packed() -> PackedField { + var r: BigInt; +{{{ r_limbs }}} + return pack_field(&r); +} + +fn mont_p(a: PackedField, b: PackedField) -> PackedField { + var a_l = unpack_field(a); + var b_l = unpack_field(b); + var out = montgomery_product(&a_l, &b_l); + return pack_field(&out); +} + +fn fr_add_p(a: PackedField, b: PackedField) -> PackedField { + var a_l = unpack_field(a); + var b_l = unpack_field(b); + var out = fr_add(&a_l, &b_l); + return pack_field(&out); +} + +fn fr_sub_p(a: PackedField, b: PackedField) -> PackedField { + var a_l = unpack_field(a); + var b_l = unpack_field(b); + var out = fr_sub(&a_l, &b_l); + return pack_field(&out); +} + +fn fr_neg_p(a: PackedField) -> PackedField { + var a_l = unpack_field(a); + var p_l: BigInt = get_p(); + var out: BigInt; + let _b = bigint_sub(&p_l, &a_l, &out); + return pack_field(&out); +} + +fn fr_inv_p(a: PackedField) -> PackedField { + let a_l = unpack_field(a); + var out = fr_inv_by_a(a_l); + return pack_field(&out); +} +`; + export const structs = `struct Point { x: BigInt, y: BigInt, diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_fused_wg_scan.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_fused_wg_scan.template.wgsl new file mode 100644 index 000000000000..846dbede058a --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_fused_wg_scan.template.wgsl @@ -0,0 +1,237 @@ +{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +{{> packed_field_funcs }} + +// Workgroup-scan fused batch-affine round kernel for v2 MSM. +// +// Each workgroup of TPB threads cooperates on BATCH_SIZE = TPB * BS pairs +// from one subtask's pair pool, performs a workgroup-level Hillis-Steele +// prefix product over per-thread chunks, runs ONE fr_inv_by_a per +// workgroup, then back-walks per-thread emitting lean affine adds. This +// is the design validated in `bench_batch_affine.template.wgsl` (22 +// ns/pair at TPB=64, BS=16 on M2) with bucket-indirect loads/stores via +// `pair_target_meta`. +// +// LAYOUT +// - All field-element variables (workgroup, function, struct fields) +// are `PackedField` (two vec4). The 20×13-bit BigInt limb form +// only exists as a transient local inside mont_p / fr_*_p / fr_inv_p. +// - Per-subtask pair pool of length n (= count_buf[subtask_idx]) is +// dispatched as ceil(n / BATCH_SIZE) workgroups in X, num_subtasks +// in Z. The last workgroup of each subtask may have a partial batch +// (n - batch_base < BATCH_SIZE); threads with chunk_start >= +// batch_len contribute identity (R in Mont form) to the scan and +// skip phase D. +// +// PHASES +// A) Per-thread serial chunk: walk BS pairs, compute dx = Q.x - P.x +// and the inclusive prefix product. Captures block_total in a +// register, writes the per-element prefix into prefix_buf. +// B) Workgroup Hillis-Steele forward + backward scan over the TPB +// block_totals (log2 TPB rounds of mont mul). +// C) Thread 0 inverts the global product via fr_inv_by_a (ONE per +// workgroup). Broadcasts to wg_inv_total. +// D) Each thread back-walks its chunk, recovers inv_dx for each pair +// from (wg_inv_total * block_excl_prefix * block_excl_suffix * +// prev_in_chunk_prefix), emits lean affine add, scatters to +// running_x/y[bucket]. +// +// SAFETY +// The scheduler emits at most one pair per (subtask, bucket) per round +// (see batch_affine_schedule). So within a workgroup's BATCH_SIZE +// slots, every `bucket` is distinct → no intra-workgroup RAW hazards +// on the running_x/y scatters. Across workgroups in the same subtask: +// disjoint slot ranges → still distinct buckets. Across subtasks +// (Z dim): different bucket ranges entirely. + +const TPB: u32 = {{ tpb }}u; +const BS: u32 = {{ bs }}u; +const BATCH_SIZE: u32 = {{ batch_size }}u; + +@group(0) @binding(0) +var val_idx: array; +@group(0) @binding(1) +var new_point_x: array>; +@group(0) @binding(2) +var new_point_y: array>; +@group(0) @binding(3) +var running_x: array>; +@group(0) @binding(4) +var running_y: array>; +@group(0) @binding(5) +var pair_target_meta: array; +@group(0) @binding(6) +var prefix_buf: array>; +@group(0) @binding(7) +var count_buf: array>; + +// params[0] = num_columns (per-subtask pool stride) +// params[1] = input_size (per-subtask val_idx stride) +@group(0) @binding(8) +var params: vec4; + +var wg_fwd: array; +var wg_bwd: array; +var wg_inv_total: PackedField; + +@compute +@workgroup_size({{ tpb }}) +fn main( + @builtin(local_invocation_id) lid: vec3, + @builtin(workgroup_id) wid: vec3, +) { + let tid = lid.x; + let wg_idx = wid.x; + let subtask_idx = wid.z; + let num_columns = params[0]; + let input_size = params[1]; + + let n = atomicLoad(&count_buf[subtask_idx]); + let batch_base = wg_idx * BATCH_SIZE; + if (batch_base >= n) { + return; + } + + let pool_base = subtask_idx * num_columns; + let vi_offset = subtask_idx * input_size; + + let remaining = n - batch_base; + let batch_len = min(BATCH_SIZE, remaining); + + let chunk_start = tid * BS; + var chunk_len: u32 = 0u; + if (chunk_start < batch_len) { + let chunk_end_unclamped = chunk_start + BS; + let chunk_end = min(chunk_end_unclamped, batch_len); + chunk_len = chunk_end - chunk_start; + } + + // Phase A — per-thread serial prefix product. Threads with + // chunk_len == 0 contribute identity (R = Mont 1) so the workgroup + // scan reads a sane value for every slot. + var block_total: PackedField = get_r_packed(); + if (chunk_len > 0u) { + let k0 = chunk_start; + let slot0 = pool_base + batch_base + k0; + let bucket0 = pair_target_meta[2u * slot0]; + let cursor0 = pair_target_meta[2u * slot0 + 1u]; + let pt_idx0 = val_idx[vi_offset + cursor0]; + let p_x0 = field_load_rw(bucket0, &running_x); + let q_x0 = field_load_ro(pt_idx0, &new_point_x); + let dx0 = fr_sub_p(q_x0, p_x0); + field_store(pool_base + batch_base + k0, &prefix_buf, dx0); + block_total = dx0; + + for (var i: u32 = 1u; i < BS; i = i + 1u) { + if (i >= chunk_len) { break; } + let k = chunk_start + i; + let slot = pool_base + batch_base + k; + let bucket = pair_target_meta[2u * slot]; + let cursor = pair_target_meta[2u * slot + 1u]; + let pt_idx = val_idx[vi_offset + cursor]; + let p_x = field_load_rw(bucket, &running_x); + let q_x = field_load_ro(pt_idx, &new_point_x); + let dx = fr_sub_p(q_x, p_x); + block_total = mont_p(block_total, dx); + field_store(pool_base + batch_base + k, &prefix_buf, block_total); + } + } + + wg_fwd[tid] = block_total; + wg_bwd[tid] = block_total; + workgroupBarrier(); + + // Phase B — Hillis-Steele forward + backward inclusive scan. + for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) { + var fwd_x: PackedField = wg_fwd[tid]; + if (tid >= stride) { + let lhs = wg_fwd[tid - stride]; + fwd_x = mont_p(lhs, fwd_x); + } + var bwd_x: PackedField = wg_bwd[tid]; + if (tid + stride < TPB) { + let rhs = wg_bwd[tid + stride]; + bwd_x = mont_p(bwd_x, rhs); + } + workgroupBarrier(); + wg_fwd[tid] = fwd_x; + wg_bwd[tid] = bwd_x; + workgroupBarrier(); + } + + // Phase C — single fr_inv per workgroup. wg_fwd[TPB-1] holds the + // product of every active (and identity-padding) block_total in the + // workgroup. + if (tid == 0u) { + let global_total = wg_fwd[TPB - 1u]; + wg_inv_total = fr_inv_p(global_total); + } + workgroupBarrier(); + + // Phase D — back-walk this thread's chunk, emit lean affine adds. + if (chunk_len == 0u) { + return; + } + + var block_excl_prefix: PackedField = get_r_packed(); + if (tid > 0u) { + block_excl_prefix = wg_fwd[tid - 1u]; + } + var block_excl_suffix: PackedField = get_r_packed(); + if (tid + 1u < TPB) { + block_excl_suffix = wg_bwd[tid + 1u]; + } + var inv_acc: PackedField = mont_p(wg_inv_total, block_excl_prefix); + inv_acc = mont_p(inv_acc, block_excl_suffix); + + for (var off: u32 = 0u; off < BS; off = off + 1u) { + if (off >= chunk_len) { break; } + let k = chunk_start + (chunk_len - 1u - off); + let slot = pool_base + batch_base + k; + let bucket = pair_target_meta[2u * slot]; + let cursor = pair_target_meta[2u * slot + 1u]; + let pt_idx = val_idx[vi_offset + cursor]; + + let p_x = field_load_rw(bucket, &running_x); + let p_y = field_load_rw(bucket, &running_y); + let q_x = field_load_ro(pt_idx, &new_point_x); + let q_y = field_load_ro(pt_idx, &new_point_y); + + var inv_dx: PackedField; + if (k > chunk_start) { + let prev = field_load_rw(pool_base + batch_base + (k - 1u), &prefix_buf); + inv_dx = mont_p(inv_acc, prev); + } else { + inv_dx = inv_acc; + } + + let dy = fr_sub_p(q_y, p_y); + let lambda = mont_p(dy, inv_dx); + let lambda_sq = mont_p(lambda, lambda); + var r_x = fr_sub_p(lambda_sq, p_x); + r_x = fr_sub_p(r_x, q_x); + let dx_back = fr_sub_p(p_x, r_x); + let ldx = mont_p(lambda, dx_back); + let r_y = fr_sub_p(ldx, p_y); + + field_store(bucket, &running_x, r_x); + field_store(bucket, &running_y, r_y); + + if (k > chunk_start) { + let dx_fwd = fr_sub_p(q_x, p_x); + inv_acc = mont_p(inv_acc, dx_fwd); + } + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/struct/packed_field.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/struct/packed_field.template.wgsl new file mode 100644 index 000000000000..c1346aa66d88 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/struct/packed_field.template.wgsl @@ -0,0 +1,131 @@ +// Packed 256-bit field-element type and primitive wrappers for v2 MSM. +// +// A `PackedField` holds one canonical [0, q) BN254 base-field value as two +// vec4 (8 × u32 = 32 bytes little-endian). Storage buffers are +// `array>` with logical stride 2 vec4s per element. +// +// Design constraint (from the v2 plan): every shader-level field-element +// variable, struct field, workgroup-shared var, and binding is +// `PackedField`. The 20×13-bit `BigInt` representation only appears as a +// transient local inside the wrappers below. No kernel ever calls +// `unpack256_to_limbs` or `pack_limbs_to_256` directly. +// +// Cost per primitive call: ~2 unpacks + 1 pack on top of the underlying +// BigInt operation. On Apple M2 each pack/unpack is ~10 cycles vs ~100 +// cycles for `montgomery_product`, so chains of mont-muls pay <15% +// overhead vs the BigInt calling convention used by the legacy +// msm_webgpu/ shaders. +// +// PRECONDITION: this partial must be included after `bigint_funcs`, +// `montgomery_product_funcs`, `field_funcs`, `by_inverse_a_funcs`, and +// after the {{{ dec_unpack }}} / {{{ dec_pack }}} substitution blocks +// have rendered `unpack256_to_limbs` / `pack_limbs_to_256` into the +// shader. + +struct PackedField { + lo: vec4, + hi: vec4, +} + +fn pf_to_words(p: PackedField) -> array { + var w: array; + w[0] = p.lo.x; w[1] = p.lo.y; w[2] = p.lo.z; w[3] = p.lo.w; + w[4] = p.hi.x; w[5] = p.hi.y; w[6] = p.hi.z; w[7] = p.hi.w; + return w; +} + +fn pf_from_words(w0: u32, w1: u32, w2: u32, w3: u32, + w4: u32, w5: u32, w6: u32, w7: u32) -> PackedField { + var p: PackedField; + p.lo = vec4(w0, w1, w2, w3); + p.hi = vec4(w4, w5, w6, w7); + return p; +} + +fn unpack_field(p: PackedField) -> BigInt { + let w = pf_to_words(p); + return unpack256_to_limbs(w); +} + +fn pack_field(b: ptr) -> PackedField { + let w = pack_limbs_to_256(b); + return pf_from_words(w[0], w[1], w[2], w[3], w[4], w[5], w[6], w[7]); +} + +fn field_load_ro(idx: u32, src: ptr>, read>) -> PackedField { + var p: PackedField; + p.lo = (*src)[2u * idx]; + p.hi = (*src)[2u * idx + 1u]; + return p; +} + +fn field_load_rw(idx: u32, src: ptr>, read_write>) -> PackedField { + var p: PackedField; + p.lo = (*src)[2u * idx]; + p.hi = (*src)[2u * idx + 1u]; + return p; +} + +fn field_store(idx: u32, dst: ptr>, read_write>, val: PackedField) { + (*dst)[2u * idx] = val.lo; + (*dst)[2u * idx + 1u] = val.hi; +} + +fn is_zero_packed(a: PackedField) -> bool { + return all(a.lo == vec4(0u, 0u, 0u, 0u)) + && all(a.hi == vec4(0u, 0u, 0u, 0u)); +} + +fn eq_packed(a: PackedField, b: PackedField) -> bool { + return all(a.lo == b.lo) && all(a.hi == b.hi); +} + +fn get_zero_packed() -> PackedField { + return PackedField(vec4(0u), vec4(0u)); +} + +fn get_p_packed() -> PackedField { + var p: BigInt = get_p(); + return pack_field(&p); +} + +fn get_r_packed() -> PackedField { + var r: BigInt; +{{{ r_limbs }}} + return pack_field(&r); +} + +fn mont_p(a: PackedField, b: PackedField) -> PackedField { + var a_l = unpack_field(a); + var b_l = unpack_field(b); + var out = montgomery_product(&a_l, &b_l); + return pack_field(&out); +} + +fn fr_add_p(a: PackedField, b: PackedField) -> PackedField { + var a_l = unpack_field(a); + var b_l = unpack_field(b); + var out = fr_add(&a_l, &b_l); + return pack_field(&out); +} + +fn fr_sub_p(a: PackedField, b: PackedField) -> PackedField { + var a_l = unpack_field(a); + var b_l = unpack_field(b); + var out = fr_sub(&a_l, &b_l); + return pack_field(&out); +} + +fn fr_neg_p(a: PackedField) -> PackedField { + var a_l = unpack_field(a); + var p_l: BigInt = get_p(); + var out: BigInt; + let _b = bigint_sub(&p_l, &a_l, &out); + return pack_field(&out); +} + +fn fr_inv_p(a: PackedField) -> PackedField { + let a_l = unpack_field(a); + var out = fr_inv_by_a(a_l); + return pack_field(&out); +} From 0ea5877a586b09cf44566d280ba867d4698e4048 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Tue, 19 May 2026 03:25:09 +0000 Subject: [PATCH 02/33] feat(bb/msm): standalone bench harness for batch_affine_fused_wg_scan MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Forks bench-batch-affine.{ts,html} as bench-fused-wg-scan.{ts,html} — exercises the new workgroup-scan fused round kernel against a synthetic flat pool (one subtask, bucket_i = i, cursor_i = i, val_idx[i] = i) with on-curve BN254 G1 affine pairs generated via noble. The output R_i = P_i + Q_i is decoded from packed Mont form and compared to noble's reference G1 add, so the bench is correctness-gated, not just perf-gated. - Sweep: BATCH_SIZE in {256, 512, 1024, 2048} at TOTAL_PAIRS = 65536 by default. Per-size correctness check followed by perf rep loop; ns/pair printed per size for the M2 comparison. - run-browserstack.mjs pageMap gets a new "bench-fused-wg-scan" entry so the BrowserStack runner can drive it via --page. No production code touched. Pure additive dev-page bench. Vite picks the new .html up automatically. --- .../dev/msm-webgpu/bench-fused-wg-scan.html | 37 ++ .../ts/dev/msm-webgpu/bench-fused-wg-scan.ts | 539 ++++++++++++++++++ .../msm-webgpu/scripts/run-browserstack.mjs | 1 + 3 files changed, 577 insertions(+) create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.html create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts diff --git a/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.html b/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.html new file mode 100644 index 000000000000..9605a79b5605 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.html @@ -0,0 +1,37 @@ + + + + + Workgroup-scan fused batch-affine round bench (WebGPU) + + + +

Workgroup-scan fused batch-affine round bench (WebGPU)

+

Query params: ?reps=R&total=N&sizes=A,B,C&skip_correctness=1

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts b/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts new file mode 100644 index 000000000000..b998be6a4e0f --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts @@ -0,0 +1,539 @@ +/// +// Standalone WebGPU bench + correctness oracle for the new +// `batch_affine_fused_wg_scan` kernel (workgroup-scan fused batch-affine +// round, validated to mirror bench_batch_affine's 22 ns/pair design with +// bucket-indirect loads via pair_target_meta). +// +// Inputs: TOTAL_PAIRS on-curve BN254 G1 affine pairs (P_i, Q_i), +// distributed across ONE synthetic subtask. Each pair maps to a unique +// bucket (`pair_target_meta[2i] = i`, `pair_target_meta[2i+1] = i`, +// `val_idx[i] = i`), so the scheduler's "distinct-buckets within a +// subtask round" invariant holds trivially. The kernel writes the +// affine sum `R_i = P_i + Q_i` to running_x[i] / running_y[i]; we +// decode packed Mont form back to canonical and compare to noble's +// reference P.add(Q). +// +// SAFETY: NO MSM pipeline touched. The shader uses only compile-time- +// const loop bounds (BS, TPB, NUM_WORDS). One dispatch per measurement. + +import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js'; +import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js'; +import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js'; +import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js'; +import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js'; +import { makeResultsClient } from './results_post.js'; +import { bn254 } from '@noble/curves/bn254'; + +const G1 = bn254.G1.ProjectivePoint; +const FR_ORDER = bn254.fields.Fr.ORDER; + +const DEFAULT_TOTAL_PAIRS = 1 << 16; +let TOTAL_PAIRS = DEFAULT_TOTAL_PAIRS; + +const DEFAULT_BATCH_SIZES = [256, 512, 1024, 2048] as const; +let BATCH_SIZES: readonly number[] = DEFAULT_BATCH_SIZES; + +let SKIP_CORRECTNESS = false; + +function tpbBsFor(batchSize: number): { tpb: number; bs: number } { + if (batchSize <= 64) return { tpb: batchSize, bs: 1 }; + return { tpb: 64, bs: batchSize / 64 }; +} + +function makeRng(seed: number): () => number { + let state = (seed >>> 0) || 1; + return () => { + state = (Math.imul(state, 1664525) + 1013904223) >>> 0; + return state; + }; +} + +function randomBelow(p: bigint, rng: () => number): bigint { + const bitlen = p.toString(2).length; + const byteLen = Math.ceil(bitlen / 8); + for (;;) { + let v = 0n; + for (let i = 0; i < byteLen; i++) v = (v << 8n) | BigInt(rng() & 0xff); + v &= (1n << BigInt(bitlen)) - 1n; + if (v > 0n && v < p) return v; + } +} + +function median(xs: number[]): number { + if (xs.length === 0) return NaN; + const s = xs.slice().sort((a, b) => a - b); + return s[Math.floor(s.length / 2)]; +} + +function biToLe32u32(v: bigint): Uint32Array { + const out = new Uint32Array(8); + let x = v; + for (let i = 0; i < 8; i++) { + out[i] = Number(x & 0xffffffffn); + x >>= 32n; + } + return out; +} + +function le32u32ToBi(u32: Uint32Array, off: number): bigint { + let v = 0n; + for (let i = 7; i >= 0; i--) v = (v << 32n) | BigInt(u32[off + i] >>> 0); + return v; +} + +interface PerSizeResult { + batch_size: number; + tpb: number; + bs: number; + num_wgs: number; + total_pairs: number; + median_ms: number; + min_ms: number; + max_ms: number; + ns_per_pair: number; + samples_ms: number[]; + correctness: 'pass' | 'fail' | 'skipped'; + correctness_first_fail?: { i: number; expected_x: string; got_x: string; expected_y: string; got_y: string }; +} + +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: { reps: number; total: number; sizes: readonly number[]; skip_correctness: boolean } | null; + results: PerSizeResult[]; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { + state: 'boot', + params: null, + results: [], + error: null, + log: [], +}; +(window as unknown as { __bench: BenchState }).__bench = benchState; + +const resultsClient = makeResultsClient({ page: 'bench-fused-wg-scan' }); +(window as unknown as { __runId: string }).__runId = resultsClient.runId; + +async function postFinal(): Promise { + await resultsClient.postResults({ + state: benchState.state, + params: benchState.params, + results: benchState.results, + error: benchState.error, + log: benchState.log, + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); +} + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-fused-wg-scan] ${msg}`); +} + +async function createPipeline( + device: GPUDevice, + code: string, + cacheKey: string, +): Promise<{ pipeline: GPUComputePipeline; layout: GPUBindGroupLayout }> { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + for (const msg of info.messages) { + const line = `[shader ${cacheKey}] ${msg.type}: ${msg.message} (line ${msg.lineNum}, col ${msg.linePos})`; + if (msg.type === 'error') { + console.error(line); + hasError = true; + } else { + console.warn(line); + } + } + if (hasError) { + throw new Error(`WGSL compile failed for ${cacheKey}`); + } + const layout = device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 5, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 6, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 7, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 8, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); + const pipeline = await device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); + return { pipeline, layout }; +} + +interface PointPair { + p: { x: bigint; y: bigint }; + q: { x: bigint; y: bigint }; + r: { x: bigint; y: bigint }; +} + +function buildPairs(n: number, seed: number): PointPair[] { + const rng = makeRng(seed); + const out: PointPair[] = []; + for (let i = 0; i < n; i++) { + let p: { x: bigint; y: bigint }; + let q: { x: bigint; y: bigint }; + let r: { x: bigint; y: bigint }; + for (;;) { + const sp = randomBelow(FR_ORDER, rng); + const sq = randomBelow(FR_ORDER, rng); + const pp = G1.BASE.multiply(sp); + const qp = G1.BASE.multiply(sq); + if (pp.is0() || qp.is0()) continue; + const pa = pp.toAffine(); + const qa = qp.toAffine(); + if (pa.x === qa.x) continue; + const rp = pp.add(qp); + if (rp.is0()) continue; + const ra = rp.toAffine(); + p = pa; + q = qa; + r = ra; + break; + } + out.push({ p, q, r }); + } + return out; +} + +async function runOne( + device: GPUDevice, + sm: ShaderManager, + batchSize: number, + reps: number, + R: bigint, + p: bigint, + pairs: PointPair[], +): Promise { + const { tpb, bs } = tpbBsFor(batchSize); + if (TOTAL_PAIRS % batchSize !== 0) { + throw new Error(`TOTAL_PAIRS=${TOTAL_PAIRS} must be a multiple of batch_size=${batchSize}`); + } + const numWgs = TOTAL_PAIRS / batchSize; + log('info', `=== batch_size=${batchSize}: TPB=${tpb} BS=${bs} num_WGs=${numWgs}`); + + const code = sm.gen_batch_affine_fused_wg_scan_shader(tpb, bs); + const cacheKey = `bench-fused-wg-scan-T${tpb}-S${bs}`; + log('info', `compiling shader (${code.length} chars)`); + (window as unknown as Record)[`__shader_${batchSize}`] = code; + const { pipeline, layout } = await createPipeline(device, code, cacheKey); + + const fieldBytes = 32; + const fieldsPerPair = 2; + const fieldsTotalIo = TOTAL_PAIRS * fieldsPerPair; + + const newPointXAB = new ArrayBuffer(TOTAL_PAIRS * fieldBytes); + const newPointYAB = new ArrayBuffer(TOTAL_PAIRS * fieldBytes); + const runningXAB = new ArrayBuffer(TOTAL_PAIRS * fieldBytes); + const runningYAB = new ArrayBuffer(TOTAL_PAIRS * fieldBytes); + + const nx32 = new Uint32Array(newPointXAB); + const ny32 = new Uint32Array(newPointYAB); + const rx32 = new Uint32Array(runningXAB); + const ry32 = new Uint32Array(runningYAB); + + for (let i = 0; i < TOTAL_PAIRS; i++) { + const { p: pp, q: qq } = pairs[i]; + const pxM = (pp.x * R) % p; + const pyM = (pp.y * R) % p; + const qxM = (qq.x * R) % p; + const qyM = (qq.y * R) % p; + rx32.set(biToLe32u32(pxM), i * 8); + ry32.set(biToLe32u32(pyM), i * 8); + nx32.set(biToLe32u32(qxM), i * 8); + ny32.set(biToLe32u32(qyM), i * 8); + } + + const valIdxAB = new Uint32Array(TOTAL_PAIRS); + const ptmAB = new Uint32Array(TOTAL_PAIRS * 2); + for (let i = 0; i < TOTAL_PAIRS; i++) { + valIdxAB[i] = i; + ptmAB[2 * i] = i; + ptmAB[2 * i + 1] = i; + } + const countAB = new Uint32Array([TOTAL_PAIRS]); + const paramsAB = new Uint32Array([TOTAL_PAIRS, TOTAL_PAIRS, 0, 0]); + + const mkSb = (size: number, copyDst = true, copySrc = false): GPUBuffer => { + let usage = GPUBufferUsage.STORAGE; + if (copyDst) usage |= GPUBufferUsage.COPY_DST; + if (copySrc) usage |= GPUBufferUsage.COPY_SRC; + return device.createBuffer({ size, usage }); + }; + + const valIdxBuf = mkSb(valIdxAB.byteLength); + const newPointXBuf = mkSb(newPointXAB.byteLength); + const newPointYBuf = mkSb(newPointYAB.byteLength); + const runningXBuf = mkSb(runningXAB.byteLength, true, true); + const runningYBuf = mkSb(runningYAB.byteLength, true, true); + const ptmBuf = mkSb(ptmAB.byteLength); + const prefixBuf = mkSb(TOTAL_PAIRS * fieldBytes, false); + const countBuf = mkSb(countAB.byteLength); + const paramsBuf = device.createBuffer({ + size: 16, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + + device.queue.writeBuffer(valIdxBuf, 0, valIdxAB); + device.queue.writeBuffer(newPointXBuf, 0, newPointXAB); + device.queue.writeBuffer(newPointYBuf, 0, newPointYAB); + device.queue.writeBuffer(ptmBuf, 0, ptmAB); + device.queue.writeBuffer(countBuf, 0, countAB); + device.queue.writeBuffer(paramsBuf, 0, paramsAB); + + const bindGroup = device.createBindGroup({ + layout, + entries: [ + { binding: 0, resource: { buffer: valIdxBuf } }, + { binding: 1, resource: { buffer: newPointXBuf } }, + { binding: 2, resource: { buffer: newPointYBuf } }, + { binding: 3, resource: { buffer: runningXBuf } }, + { binding: 4, resource: { buffer: runningYBuf } }, + { binding: 5, resource: { buffer: ptmBuf } }, + { binding: 6, resource: { buffer: prefixBuf } }, + { binding: 7, resource: { buffer: countBuf } }, + { binding: 8, resource: { buffer: paramsBuf } }, + ], + }); + + const dispatch = async () => { + device.queue.writeBuffer(runningXBuf, 0, runningXAB); + device.queue.writeBuffer(runningYBuf, 0, runningYAB); + const encoder = device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bindGroup); + pass.dispatchWorkgroups(numWgs, 1, 1); + pass.end(); + const t0 = performance.now(); + device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + return performance.now() - t0; + }; + + await dispatch(); + log('info', 'warmup dispatch returned'); + + let correctness: 'pass' | 'fail' | 'skipped' = 'skipped'; + let correctness_first_fail: PerSizeResult['correctness_first_fail']; + + if (!SKIP_CORRECTNESS) { + const stagingX = device.createBuffer({ + size: TOTAL_PAIRS * fieldBytes, + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + }); + const stagingY = device.createBuffer({ + size: TOTAL_PAIRS * fieldBytes, + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + }); + const enc = device.createCommandEncoder(); + enc.copyBufferToBuffer(runningXBuf, 0, stagingX, 0, TOTAL_PAIRS * fieldBytes); + enc.copyBufferToBuffer(runningYBuf, 0, stagingY, 0, TOTAL_PAIRS * fieldBytes); + device.queue.submit([enc.finish()]); + await Promise.all([stagingX.mapAsync(GPUMapMode.READ), stagingY.mapAsync(GPUMapMode.READ)]); + const gpuX = new Uint32Array(stagingX.getMappedRange().slice(0)); + const gpuY = new Uint32Array(stagingY.getMappedRange().slice(0)); + stagingX.unmap(); + stagingY.unmap(); + stagingX.destroy(); + stagingY.destroy(); + + const rInv = (() => { + let g = R % p; + let r = p; + let x = 0n; + let y = 1n; + while (g !== 0n) { + const q = r / g; + [r, g] = [g, r - q * g]; + [x, y] = [y, x - q * y]; + } + return ((x % p) + p) % p; + })(); + + let mismatches = 0; + const MAX_REPORTED = 4; + for (let i = 0; i < TOTAL_PAIRS; i++) { + const gxM = le32u32ToBi(gpuX, i * 8); + const gyM = le32u32ToBi(gpuY, i * 8); + const gx = (gxM * rInv) % p; + const gy = (gyM * rInv) % p; + const ex = pairs[i].r.x; + const ey = pairs[i].r.y; + if (gx !== ex || gy !== ey) { + mismatches++; + if (!correctness_first_fail) { + correctness_first_fail = { + i, + expected_x: ex.toString(), + got_x: gx.toString(), + expected_y: ey.toString(), + got_y: gy.toString(), + }; + } + if (mismatches > MAX_REPORTED) break; + } + } + correctness = mismatches === 0 ? 'pass' : 'fail'; + if (mismatches === 0) { + log('ok', `correctness: pass (${TOTAL_PAIRS}/${TOTAL_PAIRS} pairs match noble reference)`); + } else { + log('err', `correctness: FAIL (${mismatches}+ mismatches; first @ i=${correctness_first_fail!.i})`); + log('err', ` expected R.x = ${correctness_first_fail!.expected_x.slice(0, 24)}...`); + log('err', ` got R.x = ${correctness_first_fail!.got_x.slice(0, 24)}...`); + } + } + + const samples: number[] = []; + for (let r = 0; r < reps; r++) { + samples.push(await dispatch()); + } + const med = median(samples); + const mn = Math.min(...samples); + const mx = Math.max(...samples); + const nsPerPair = (med * 1e6) / TOTAL_PAIRS; + + log( + correctness === 'fail' ? 'err' : 'ok', + `B=${batchSize}: median=${med.toFixed(3)}ms min=${mn.toFixed(3)}ms max=${mx.toFixed(3)}ms ns/pair=${nsPerPair.toFixed(1)} correctness=${correctness}`, + ); + + valIdxBuf.destroy(); + newPointXBuf.destroy(); + newPointYBuf.destroy(); + runningXBuf.destroy(); + runningYBuf.destroy(); + ptmBuf.destroy(); + prefixBuf.destroy(); + countBuf.destroy(); + paramsBuf.destroy(); + + return { + batch_size: batchSize, + tpb, + bs, + num_wgs: numWgs, + total_pairs: TOTAL_PAIRS, + median_ms: med, + min_ms: mn, + max_ms: mx, + ns_per_pair: nsPerPair, + samples_ms: samples, + correctness, + correctness_first_fail, + }; +} + +function parseParams() { + const qp = new URLSearchParams(window.location.search); + const reps = parseInt(qp.get('reps') ?? '5', 10); + if (!Number.isFinite(reps) || reps <= 0 || reps > 50) { + throw new Error(`?reps must be in (0, 50], got ${qp.get('reps')}`); + } + const totalStr = qp.get('total'); + if (totalStr !== null) { + const total = parseInt(totalStr, 10); + if (!Number.isFinite(total) || total <= 0 || total > (1 << 20)) { + throw new Error(`?total must be in (0, 2^20], got ${totalStr}`); + } + TOTAL_PAIRS = total; + } + const sizesStr = qp.get('sizes'); + if (sizesStr !== null) { + const sizes = sizesStr.split(',').map(s => parseInt(s, 10)); + for (const s of sizes) { + if (!Number.isFinite(s) || s <= 0 || s > 4096) { + throw new Error(`?sizes entries must be in (0, 4096], got ${s}`); + } + if (TOTAL_PAIRS % s !== 0) { + throw new Error(`?sizes entry ${s} does not divide TOTAL_PAIRS=${TOTAL_PAIRS}`); + } + const { tpb } = tpbBsFor(s); + if ((tpb & (tpb - 1)) !== 0) { + throw new Error(`?sizes entry ${s} yields TPB=${tpb} which is not a power of two`); + } + } + BATCH_SIZES = sizes; + } + if (qp.get('skip_correctness') === '1') { + SKIP_CORRECTNESS = true; + } + return { reps, total: TOTAL_PAIRS, sizes: BATCH_SIZES, skip_correctness: SKIP_CORRECTNESS }; +} + +async function main() { + try { + if (!('gpu' in navigator)) { + throw new Error('navigator.gpu missing — WebGPU not available'); + } + const params = parseParams(); + benchState.params = params; + log('info', `params: reps=${params.reps} total=${params.total} sizes=[${params.sizes.join(',')}] skip_correctness=${params.skip_correctness}`); + + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + + const p = BN254_BASE_FIELD; + const miscParams = compute_misc_params(p, 13); + const R = miscParams.r; + + log('info', `generating ${TOTAL_PAIRS} on-curve pairs via noble (this can take a few seconds)…`); + const t0 = performance.now(); + const pairs = buildPairs(TOTAL_PAIRS, 0xc0ffee); + log('info', `pair generation done in ${(performance.now() - t0).toFixed(0)} ms`); + + const sm = new ShaderManager(4, TOTAL_PAIRS, BN254_CURVE_CONFIG, false); + + for (const B of BATCH_SIZES) { + try { + const r = await runOne(device, sm, B, params.reps, R, p, pairs); + benchState.results.push(r); + resultsClient.postProgress({ kind: 'batch_done', batch_size: B, median_ms: r.median_ms, ns_per_pair: r.ns_per_pair, correctness: r.correctness }); + } catch (e) { + const msg = e instanceof Error ? e.message : String(e); + log('err', `B=${B} failed: ${msg} — STOPPING sweep at first failure`); + benchState.state = 'error'; + benchState.error = msg; + return; + } + } + + benchState.state = 'done'; + log('ok', 'all batches done'); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main() + .catch(e => { + const msg = e instanceof Error ? e.message : String(e); + log('err', `unhandled: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + }) + .finally(() => { + postFinal().catch(() => {}); + }); diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs index a9e3eacd368d..c503e9be2230 100644 --- a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs +++ b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs @@ -127,6 +127,7 @@ if (!TARGETS[argv.target]) { const pageMap = { "bench-batch-affine": "/dev/msm-webgpu/bench-batch-affine.html", + "bench-fused-wg-scan": "/dev/msm-webgpu/bench-fused-wg-scan.html", "bench-smvp-tree": "/dev/msm-webgpu/bench-smvp-tree.html", sanity: "/dev/msm-webgpu/index.html", }; From 0784e7873c416b966fa77d4fa26653f0740b53ca Mon Sep 17 00:00:00 2001 From: AztecBot Date: Tue, 19 May 2026 03:45:29 +0000 Subject: [PATCH 03/33] fix(bb/msm): three WGSL-compile fixes for fused_wg_scan kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Surfaced by running the bench harness on BrowserStack M2 (Chrome 148). Each fix is a Tint diagnostic the local render check could not catch. 1. packed_field.template.wgsl — the comment referenced `{{{ dec_unpack }}}` as literal mustache syntax. Mustache substituted the rendered `unpack256_to_limbs` body INTO the comment, so its second line broke out and Tint saw a function body at top level. Reworded the comment to not reference mustache tags. 2. batch_affine_fused_wg_scan — `count_buf: array>` was declared `read`, but atomics in storage must be `read_write` per WGSL spec. Flipped to read_write; bench TS bind-group layout updated to match. 3. batch_affine_fused_wg_scan — the early-return `if (batch_base >= n) return;` and `if (chunk_len == 0u) return;` are guarded by values derived from atomicLoad and so are considered non-uniform by Tint's uniformity analysis. The subsequent workgroupBarrier then sat in non-uniform control flow and was rejected. Restructured both phases to NOT early-return: threads with no work clamp chunk_len=0 and contribute identity (R = Mont 1) to the workgroup scan, skipping the actual load/store loops but staying live through barriers. 4. packed_field.template.wgsl — added `fn get_r()` (plain BigInt return) needed by fr_pow's body. Used to live in each call-site shader template; living in the packed_field partial makes it reusable across v2 shaders without duplication. Correctness result on M2 after these fixes (TOTAL=4096, B=256): correctness=pass (4096/4096 pairs vs noble reference) median_ms=1.10, ns/pair=268 (small-N baseline; sweep to follow) --- .../ts/dev/msm-webgpu/bench-fused-wg-scan.ts | 8 +++- .../src/msm_webgpu/wgsl/_generated/shaders.ts | 40 ++++++++++--------- .../batch_affine_fused_wg_scan.template.wgsl | 22 +++++----- .../wgsl/struct/packed_field.template.wgsl | 18 +++++---- 4 files changed, 48 insertions(+), 40 deletions(-) diff --git a/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts b/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts index b998be6a4e0f..4825e63cf9bc 100644 --- a/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts +++ b/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts @@ -147,17 +147,21 @@ async function createPipeline( const module = device.createShaderModule({ code }); const info = await module.getCompilationInfo(); let hasError = false; + const errLines: string[] = []; for (const msg of info.messages) { const line = `[shader ${cacheKey}] ${msg.type}: ${msg.message} (line ${msg.lineNum}, col ${msg.linePos})`; if (msg.type === 'error') { console.error(line); + log('err', line); + errLines.push(line); hasError = true; } else { console.warn(line); + log('warn', line); } } if (hasError) { - throw new Error(`WGSL compile failed for ${cacheKey}`); + throw new Error(`WGSL compile failed for ${cacheKey}: ${errLines.join(' | ')}`); } const layout = device.createBindGroupLayout({ entries: [ @@ -168,7 +172,7 @@ async function createPipeline( { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, { binding: 5, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, { binding: 6, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, - { binding: 7, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 7, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, { binding: 8, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, ], }); diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index d09c0dea0f1a..0771ec25d4a9 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -2735,7 +2735,7 @@ var pair_target_meta: array; @group(0) @binding(6) var prefix_buf: array>; @group(0) @binding(7) -var count_buf: array>; +var count_buf: array>; // params[0] = num_columns (per-subtask pool stride) // params[1] = input_size (per-subtask val_idx stride) @@ -2760,21 +2760,19 @@ fn main( let n = atomicLoad(&count_buf[subtask_idx]); let batch_base = wg_idx * BATCH_SIZE; - if (batch_base >= n) { - return; - } let pool_base = subtask_idx * num_columns; let vi_offset = subtask_idx * input_size; - let remaining = n - batch_base; - let batch_len = min(BATCH_SIZE, remaining); + var batch_len: u32 = 0u; + if (batch_base < n) { + batch_len = min(BATCH_SIZE, n - batch_base); + } let chunk_start = tid * BS; var chunk_len: u32 = 0u; if (chunk_start < batch_len) { - let chunk_end_unclamped = chunk_start + BS; - let chunk_end = min(chunk_end_unclamped, batch_len); + let chunk_end = min(chunk_start + BS, batch_len); chunk_len = chunk_end - chunk_start; } @@ -2841,10 +2839,10 @@ fn main( workgroupBarrier(); // Phase D — back-walk this thread's chunk, emit lean affine adds. - if (chunk_len == 0u) { - return; - } - + // Threads with chunk_len == 0 (overshoot dispatch or end-of-pool + // padding) skip the work loop entirely but stay live through any + // future workgroup-uniform code (currently none — D is the last + // phase). var block_excl_prefix: PackedField = get_r_packed(); if (tid > 0u) { block_excl_prefix = wg_fwd[tid - 1u]; @@ -9901,11 +9899,10 @@ export const packed_field = `// Packed 256-bit field-element type and primitive // overhead vs the BigInt calling convention used by the legacy // msm_webgpu/ shaders. // -// PRECONDITION: this partial must be included after \`bigint_funcs\`, -// \`montgomery_product_funcs\`, \`field_funcs\`, \`by_inverse_a_funcs\`, and -// after the {{{ dec_unpack }}} / {{{ dec_pack }}} substitution blocks -// have rendered \`unpack256_to_limbs\` / \`pack_limbs_to_256\` into the -// shader. +// PRECONDITION: this partial must be included after bigint_funcs, +// montgomery_product_funcs, field_funcs, by_inverse_a_funcs, and after +// the host has injected unpack256_to_limbs and pack_limbs_to_256 (those +// come from the decoupledPackUnpackWgsl() generator in shader_manager). struct PackedField { lo: vec4, @@ -9969,14 +9966,19 @@ fn get_zero_packed() -> PackedField { return PackedField(vec4(0u), vec4(0u)); } +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + fn get_p_packed() -> PackedField { var p: BigInt = get_p(); return pack_field(&p); } fn get_r_packed() -> PackedField { - var r: BigInt; -{{{ r_limbs }}} + var r: BigInt = get_r(); return pack_field(&r); } diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_fused_wg_scan.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_fused_wg_scan.template.wgsl index 846dbede058a..a3230fc3c2b3 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_fused_wg_scan.template.wgsl +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_fused_wg_scan.template.wgsl @@ -73,7 +73,7 @@ var pair_target_meta: array; @group(0) @binding(6) var prefix_buf: array>; @group(0) @binding(7) -var count_buf: array>; +var count_buf: array>; // params[0] = num_columns (per-subtask pool stride) // params[1] = input_size (per-subtask val_idx stride) @@ -98,21 +98,19 @@ fn main( let n = atomicLoad(&count_buf[subtask_idx]); let batch_base = wg_idx * BATCH_SIZE; - if (batch_base >= n) { - return; - } let pool_base = subtask_idx * num_columns; let vi_offset = subtask_idx * input_size; - let remaining = n - batch_base; - let batch_len = min(BATCH_SIZE, remaining); + var batch_len: u32 = 0u; + if (batch_base < n) { + batch_len = min(BATCH_SIZE, n - batch_base); + } let chunk_start = tid * BS; var chunk_len: u32 = 0u; if (chunk_start < batch_len) { - let chunk_end_unclamped = chunk_start + BS; - let chunk_end = min(chunk_end_unclamped, batch_len); + let chunk_end = min(chunk_start + BS, batch_len); chunk_len = chunk_end - chunk_start; } @@ -179,10 +177,10 @@ fn main( workgroupBarrier(); // Phase D — back-walk this thread's chunk, emit lean affine adds. - if (chunk_len == 0u) { - return; - } - + // Threads with chunk_len == 0 (overshoot dispatch or end-of-pool + // padding) skip the work loop entirely but stay live through any + // future workgroup-uniform code (currently none — D is the last + // phase). var block_excl_prefix: PackedField = get_r_packed(); if (tid > 0u) { block_excl_prefix = wg_fwd[tid - 1u]; diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/struct/packed_field.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/struct/packed_field.template.wgsl index c1346aa66d88..8a3c481bcece 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/struct/packed_field.template.wgsl +++ b/barretenberg/ts/src/msm_webgpu/wgsl/struct/packed_field.template.wgsl @@ -16,11 +16,10 @@ // overhead vs the BigInt calling convention used by the legacy // msm_webgpu/ shaders. // -// PRECONDITION: this partial must be included after `bigint_funcs`, -// `montgomery_product_funcs`, `field_funcs`, `by_inverse_a_funcs`, and -// after the {{{ dec_unpack }}} / {{{ dec_pack }}} substitution blocks -// have rendered `unpack256_to_limbs` / `pack_limbs_to_256` into the -// shader. +// PRECONDITION: this partial must be included after bigint_funcs, +// montgomery_product_funcs, field_funcs, by_inverse_a_funcs, and after +// the host has injected unpack256_to_limbs and pack_limbs_to_256 (those +// come from the decoupledPackUnpackWgsl() generator in shader_manager). struct PackedField { lo: vec4, @@ -84,14 +83,19 @@ fn get_zero_packed() -> PackedField { return PackedField(vec4(0u), vec4(0u)); } +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + fn get_p_packed() -> PackedField { var p: BigInt = get_p(); return pack_field(&p); } fn get_r_packed() -> PackedField { - var r: BigInt; -{{{ r_limbs }}} + var r: BigInt = get_r(); return pack_field(&r); } From c696739467143ce726b06378c68f9068904aec62 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Tue, 19 May 2026 14:07:41 +0000 Subject: [PATCH 04/33] refactor(bb/msm): drop PackedField wrappers; BigInt internals + packed I/O only MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After running the kernel three ways on BrowserStack M2: - workgroup-scan + PackedField wrappers (binding-wide unpack/pack inside every primitive) : 55 ns/pair - per-thread BATCH=16 + BigInt internals + packed I/O : 56 ns/pair - workgroup-scan + BigInt internals + BigInt prefix_buf : 76 ns/pair All three correctness=pass. The PackedField wrapper layer adds ~30-50 ns/pair of pack/unpack overhead by repacking between primitive calls. The right read of the design constraint is "unpack only at the storage I/O boundary"; kernel-local vars and primitives stay in 20×13-bit BigInt limb form for the whole kernel body. This commit: - packed_field.template.wgsl reduced to the storage-I/O boundary: field_load_ro/_rw return BigInt directly (unpack at the load), field_store packs a BigInt into the packed 8×u32 slot. No PackedField struct, no mont_p/fr_*_p wrappers. get_r() stays. - batch_affine_fused_wg_scan.template.wgsl: kept the workgroup-scan shape (TPB=64 cooperate on TPB*BS pairs per workgroup with one fr_inv_by_a per workgroup) but every field-element variable is now BigInt. Primitives are the existing montgomery_product / fr_sub / fr_add / fr_inv_by_a, no wrappers. prefix_buf moves to array in storage. - bench TS bind-group layout updated for the 9-binding shape (added back prefix_buf). NUM_LIMBS_U32 inlined as 20. WGSL/TS fixes surfaced during BS validation: - `active` is a reserved keyword in WGSL — renamed local to `in_pool`. - prefixBuf was sized using an undefined `NUM_LIMBS_U32` constant in the bench harness; inlined the value (20). Open questions captured in the analysis gist: https://gist.github.com/AztecBot/6b3c2702f313a16fbd645355f1789fc3 - what fraction of the 22 ns → 55 ns gap is packed cost vs bucket indirection vs Tint codegen quirk - whether the bench's 22 ns/pair number still reproduces on current M2 (BS was severely queue-wedged during this session; could not re-measure) --- .../ts/dev/msm-webgpu/bench-fused-wg-scan.ts | 2 +- .../src/msm_webgpu/wgsl/_generated/shaders.ts | 360 +++++++----------- .../batch_affine_fused_wg_scan.template.wgsl | 225 ++++++----- .../wgsl/struct/packed_field.template.wgsl | 145 ++----- 4 files changed, 266 insertions(+), 466 deletions(-) diff --git a/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts b/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts index 4825e63cf9bc..71b68c99284c 100644 --- a/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts +++ b/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts @@ -289,7 +289,7 @@ async function runOne( const runningXBuf = mkSb(runningXAB.byteLength, true, true); const runningYBuf = mkSb(runningYAB.byteLength, true, true); const ptmBuf = mkSb(ptmAB.byteLength); - const prefixBuf = mkSb(TOTAL_PAIRS * fieldBytes, false); + const prefixBuf = mkSb(TOTAL_PAIRS * 20 * 4, false); const countBuf = mkSb(countAB.byteLength); const paramsBuf = device.createBuffer({ size: 16, diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index 0771ec25d4a9..743b189213a0 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -2676,31 +2676,21 @@ export const batch_affine_fused_wg_scan = `{{> structs }} // Workgroup-scan fused batch-affine round kernel for v2 MSM. // -// Each workgroup of TPB threads cooperates on BATCH_SIZE = TPB * BS pairs -// from one subtask's pair pool, performs a workgroup-level Hillis-Steele -// prefix product over per-thread chunks, runs ONE fr_inv_by_a per -// workgroup, then back-walks per-thread emitting lean affine adds. This -// is the design validated in \`bench_batch_affine.template.wgsl\` (22 -// ns/pair at TPB=64, BS=16 on M2) with bucket-indirect loads/stores via -// \`pair_target_meta\`. -// -// LAYOUT -// - All field-element variables (workgroup, function, struct fields) -// are \`PackedField\` (two vec4). The 20×13-bit BigInt limb form -// only exists as a transient local inside mont_p / fr_*_p / fr_inv_p. -// - Per-subtask pair pool of length n (= count_buf[subtask_idx]) is -// dispatched as ceil(n / BATCH_SIZE) workgroups in X, num_subtasks -// in Z. The last workgroup of each subtask may have a partial batch -// (n - batch_base < BATCH_SIZE); threads with chunk_start >= -// batch_len contribute identity (R in Mont form) to the scan and -// skip phase D. +// Mirrors \`bench_batch_affine.template.wgsl\`'s phases A/B/C/D — TPB +// threads cooperating on BATCH_SIZE = TPB*BS pairs per workgroup with +// one fr_inv_by_a per workgroup — adapted for the MSM pipeline: +// - storage is packed 8×u32 per field element (vs the bench's +// BigInt-array storage); conversions happen only at field_load_* +// and field_store, every kernel-local var holds BigInt limbs +// - loads are bucket-indirect via \`pair_target_meta\` (vs the bench's +// flat \`inputs[pair_base + *]\`) // // PHASES -// A) Per-thread serial chunk: walk BS pairs, compute dx = Q.x - P.x -// and the inclusive prefix product. Captures block_total in a -// register, writes the per-element prefix into prefix_buf. -// B) Workgroup Hillis-Steele forward + backward scan over the TPB -// block_totals (log2 TPB rounds of mont mul). +// A) Per-thread serial prefix product over BS pairs. Each thread +// writes its prefix-product chain to \`prefix[batch_base + k]\` +// (global storage) and captures \`block_total\` in a register. +// B) Workgroup-shared Hillis-Steele forward + backward scan over the +// TPB block_totals (log2 TPB rounds of mont mul). // C) Thread 0 inverts the global product via fr_inv_by_a (ONE per // workgroup). Broadcasts to wg_inv_total. // D) Each thread back-walks its chunk, recovers inv_dx for each pair @@ -2709,12 +2699,20 @@ export const batch_affine_fused_wg_scan = `{{> structs }} // running_x/y[bucket]. // // SAFETY -// The scheduler emits at most one pair per (subtask, bucket) per round -// (see batch_affine_schedule). So within a workgroup's BATCH_SIZE -// slots, every \`bucket\` is distinct → no intra-workgroup RAW hazards -// on the running_x/y scatters. Across workgroups in the same subtask: -// disjoint slot ranges → still distinct buckets. Across subtasks -// (Z dim): different bucket ranges entirely. +// The scheduler emits at most one pair per (subtask, bucket) per +// round. Within a workgroup's BATCH_SIZE slots, every \`bucket\` is +// distinct → no intra-workgroup RAW hazard on the running_x/y +// scatters. Across workgroups in the same subtask: disjoint slot +// ranges → still distinct buckets. Across subtasks (Z dim): different +// bucket ranges entirely. +// +// DISPATCH +// workgroup_size = TPB. Workgroups in X = ceil(n / (TPB*BS)). +// Workgroups in Z = num_subtasks. The atomicLoad of count_buf and +// subsequent control flow are uniform within a workgroup (every +// thread sees the same \`n\`), but Tint can't prove that — so we never +// early-return based on it. Instead, partial-batch threads contribute +// identity to the scan and skip their work loop bodies. const TPB: u32 = {{ tpb }}u; const BS: u32 = {{ bs }}u; @@ -2733,7 +2731,7 @@ var running_y: array>; @group(0) @binding(5) var pair_target_meta: array; @group(0) @binding(6) -var prefix_buf: array>; +var prefix_buf: array; @group(0) @binding(7) var count_buf: array>; @@ -2742,9 +2740,9 @@ var count_buf: array>; @group(0) @binding(8) var params: vec4; -var wg_fwd: array; -var wg_bwd: array; -var wg_inv_total: PackedField; +var wg_fwd: array; +var wg_bwd: array; +var wg_inv_total: BigInt; @compute @workgroup_size({{ tpb }}) @@ -2764,46 +2762,38 @@ fn main( let pool_base = subtask_idx * num_columns; let vi_offset = subtask_idx * input_size; - var batch_len: u32 = 0u; - if (batch_base < n) { - batch_len = min(BATCH_SIZE, n - batch_base); - } - let chunk_start = tid * BS; - var chunk_len: u32 = 0u; - if (chunk_start < batch_len) { - let chunk_end = min(chunk_start + BS, batch_len); - chunk_len = chunk_end - chunk_start; - } - - // Phase A — per-thread serial prefix product. Threads with - // chunk_len == 0 contribute identity (R = Mont 1) so the workgroup - // scan reads a sane value for every slot. - var block_total: PackedField = get_r_packed(); - if (chunk_len > 0u) { - let k0 = chunk_start; - let slot0 = pool_base + batch_base + k0; - let bucket0 = pair_target_meta[2u * slot0]; - let cursor0 = pair_target_meta[2u * slot0 + 1u]; - let pt_idx0 = val_idx[vi_offset + cursor0]; - let p_x0 = field_load_rw(bucket0, &running_x); - let q_x0 = field_load_ro(pt_idx0, &new_point_x); - let dx0 = fr_sub_p(q_x0, p_x0); - field_store(pool_base + batch_base + k0, &prefix_buf, dx0); - block_total = dx0; + let chunk_pool_base = pool_base + batch_base + chunk_start; + + let in_pool = batch_base + chunk_start + BS <= n; + // Phase A — per-thread serial prefix product. Inin_pool threads + // (chunk past the live pool) contribute identity (R = Mont 1) so + // the workgroup scan reads a sane value at every slot. + var block_total: BigInt = get_r(); + if (in_pool) { + { + let k0 = 0u; + let slot = chunk_pool_base + k0; + let bucket = pair_target_meta[2u * slot]; + let q_cursor = pair_target_meta[2u * slot + 1u]; + let pt_idx = val_idx[vi_offset + q_cursor]; + var p_x: BigInt = field_load_rw(bucket, &running_x); + var q_x: BigInt = field_load_ro(pt_idx, &new_point_x); + var dx: BigInt = fr_sub(&q_x, &p_x); + prefix_buf[chunk_pool_base + k0] = dx; + block_total = dx; + } for (var i: u32 = 1u; i < BS; i = i + 1u) { - if (i >= chunk_len) { break; } - let k = chunk_start + i; - let slot = pool_base + batch_base + k; + let slot = chunk_pool_base + i; let bucket = pair_target_meta[2u * slot]; - let cursor = pair_target_meta[2u * slot + 1u]; - let pt_idx = val_idx[vi_offset + cursor]; - let p_x = field_load_rw(bucket, &running_x); - let q_x = field_load_ro(pt_idx, &new_point_x); - let dx = fr_sub_p(q_x, p_x); - block_total = mont_p(block_total, dx); - field_store(pool_base + batch_base + k, &prefix_buf, block_total); + let q_cursor = pair_target_meta[2u * slot + 1u]; + let pt_idx = val_idx[vi_offset + q_cursor]; + var p_x: BigInt = field_load_rw(bucket, &running_x); + var q_x: BigInt = field_load_ro(pt_idx, &new_point_x); + var dx: BigInt = fr_sub(&q_x, &p_x); + block_total = montgomery_product(&block_total, &dx); + prefix_buf[chunk_pool_base + i] = block_total; } } @@ -2813,15 +2803,15 @@ fn main( // Phase B — Hillis-Steele forward + backward inclusive scan. for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) { - var fwd_x: PackedField = wg_fwd[tid]; + var fwd_x: BigInt = wg_fwd[tid]; if (tid >= stride) { - let lhs = wg_fwd[tid - stride]; - fwd_x = mont_p(lhs, fwd_x); + var lhs: BigInt = wg_fwd[tid - stride]; + fwd_x = montgomery_product(&lhs, &fwd_x); } - var bwd_x: PackedField = wg_bwd[tid]; + var bwd_x: BigInt = wg_bwd[tid]; if (tid + stride < TPB) { - let rhs = wg_bwd[tid + stride]; - bwd_x = mont_p(bwd_x, rhs); + var rhs: BigInt = wg_bwd[tid + stride]; + bwd_x = montgomery_product(&bwd_x, &rhs); } workgroupBarrier(); wg_fwd[tid] = fwd_x; @@ -2829,67 +2819,64 @@ fn main( workgroupBarrier(); } - // Phase C — single fr_inv per workgroup. wg_fwd[TPB-1] holds the - // product of every active (and identity-padding) block_total in the - // workgroup. + // Phase C — single fr_inv per workgroup. if (tid == 0u) { - let global_total = wg_fwd[TPB - 1u]; - wg_inv_total = fr_inv_p(global_total); + var global_total: BigInt = wg_fwd[TPB - 1u]; + wg_inv_total = fr_inv_by_a(global_total); } workgroupBarrier(); // Phase D — back-walk this thread's chunk, emit lean affine adds. - // Threads with chunk_len == 0 (overshoot dispatch or end-of-pool - // padding) skip the work loop entirely but stay live through any - // future workgroup-uniform code (currently none — D is the last - // phase). - var block_excl_prefix: PackedField = get_r_packed(); + if (!in_pool) { + return; + } + var block_excl_prefix: BigInt = get_r(); if (tid > 0u) { block_excl_prefix = wg_fwd[tid - 1u]; } - var block_excl_suffix: PackedField = get_r_packed(); + var block_excl_suffix: BigInt = get_r(); if (tid + 1u < TPB) { block_excl_suffix = wg_bwd[tid + 1u]; } - var inv_acc: PackedField = mont_p(wg_inv_total, block_excl_prefix); - inv_acc = mont_p(inv_acc, block_excl_suffix); + var inv_global: BigInt = wg_inv_total; + var inv_acc: BigInt = montgomery_product(&inv_global, &block_excl_prefix); + inv_acc = montgomery_product(&inv_acc, &block_excl_suffix); for (var off: u32 = 0u; off < BS; off = off + 1u) { - if (off >= chunk_len) { break; } - let k = chunk_start + (chunk_len - 1u - off); - let slot = pool_base + batch_base + k; + let k = BS - 1u - off; + let slot = chunk_pool_base + k; let bucket = pair_target_meta[2u * slot]; - let cursor = pair_target_meta[2u * slot + 1u]; - let pt_idx = val_idx[vi_offset + cursor]; + let q_cursor = pair_target_meta[2u * slot + 1u]; + let pt_idx = val_idx[vi_offset + q_cursor]; - let p_x = field_load_rw(bucket, &running_x); - let p_y = field_load_rw(bucket, &running_y); - let q_x = field_load_ro(pt_idx, &new_point_x); - let q_y = field_load_ro(pt_idx, &new_point_y); + var p_x: BigInt = field_load_rw(bucket, &running_x); + var p_y: BigInt = field_load_rw(bucket, &running_y); + var q_x: BigInt = field_load_ro(pt_idx, &new_point_x); + var q_y: BigInt = field_load_ro(pt_idx, &new_point_y); - var inv_dx: PackedField; - if (k > chunk_start) { - let prev = field_load_rw(pool_base + batch_base + (k - 1u), &prefix_buf); - inv_dx = mont_p(inv_acc, prev); + var inv_dx: BigInt; + if (k > 0u) { + var prev_prefix: BigInt = prefix_buf[chunk_pool_base + (k - 1u)]; + inv_dx = montgomery_product(&inv_acc, &prev_prefix); } else { inv_dx = inv_acc; } - let dy = fr_sub_p(q_y, p_y); - let lambda = mont_p(dy, inv_dx); - let lambda_sq = mont_p(lambda, lambda); - var r_x = fr_sub_p(lambda_sq, p_x); - r_x = fr_sub_p(r_x, q_x); - let dx_back = fr_sub_p(p_x, r_x); - let ldx = mont_p(lambda, dx_back); - let r_y = fr_sub_p(ldx, p_y); + var dy: BigInt = fr_sub(&q_y, &p_y); + var lambda: BigInt = montgomery_product(&dy, &inv_dx); + var lambda_sq: BigInt = montgomery_product(&lambda, &lambda); + var t1: BigInt = fr_sub(&lambda_sq, &p_x); + var r_x: BigInt = fr_sub(&t1, &q_x); + var dx_back: BigInt = fr_sub(&p_x, &r_x); + var ldx: BigInt = montgomery_product(&lambda, &dx_back); + var r_y: BigInt = fr_sub(&ldx, &p_y); - field_store(bucket, &running_x, r_x); - field_store(bucket, &running_y, r_y); + field_store(bucket, &running_x, &r_x); + field_store(bucket, &running_y, &r_y); - if (k > chunk_start) { - let dx_fwd = fr_sub_p(q_x, p_x); - inv_acc = mont_p(inv_acc, dx_fwd); + if (k > 0u) { + var dx_k: BigInt = fr_sub(&q_x, &p_x); + inv_acc = montgomery_product(&inv_acc, &dx_k); } } @@ -9881,140 +9868,53 @@ fn mulhilo2(a: vec2, b: vec2) -> vec4 { } `; -export const packed_field = `// Packed 256-bit field-element type and primitive wrappers for v2 MSM. -// -// A \`PackedField\` holds one canonical [0, q) BN254 base-field value as two -// vec4 (8 × u32 = 32 bytes little-endian). Storage buffers are -// \`array>\` with logical stride 2 vec4s per element. +export const packed_field = `// Packed 256-bit field-element storage helpers for v2 MSM. // -// Design constraint (from the v2 plan): every shader-level field-element -// variable, struct field, workgroup-shared var, and binding is -// \`PackedField\`. The 20×13-bit \`BigInt\` representation only appears as a -// transient local inside the wrappers below. No kernel ever calls -// \`unpack256_to_limbs\` or \`pack_limbs_to_256\` directly. +// Storage convention: every field-element buffer is \`array>\` +// with logical stride 2 vec4s per element (8 × u32 = 32 bytes, +// canonical little-endian 256-bit value, value < q < 2^254). // -// Cost per primitive call: ~2 unpacks + 1 pack on top of the underlying -// BigInt operation. On Apple M2 each pack/unpack is ~10 cycles vs ~100 -// cycles for \`montgomery_product\`, so chains of mont-muls pay <15% -// overhead vs the BigInt calling convention used by the legacy -// msm_webgpu/ shaders. +// Conversions between the packed storage layout and the 20×13-bit +// \`BigInt\` arithmetic representation happen ONLY at the storage I/O +// boundary (field_load_*, field_store, fold_packed_pair). Once loaded, +// values live as BigInt limbs for the entire kernel body and only +// repack on the final write. This matches the bench_batch_affine design +// that hit ~22 ns/pair on M2; the prior PackedField-wrapper design +// repacked between every mont and paid ~2× the cost. // // PRECONDITION: this partial must be included after bigint_funcs, // montgomery_product_funcs, field_funcs, by_inverse_a_funcs, and after // the host has injected unpack256_to_limbs and pack_limbs_to_256 (those // come from the decoupledPackUnpackWgsl() generator in shader_manager). -struct PackedField { - lo: vec4, - hi: vec4, -} - -fn pf_to_words(p: PackedField) -> array { - var w: array; - w[0] = p.lo.x; w[1] = p.lo.y; w[2] = p.lo.z; w[3] = p.lo.w; - w[4] = p.hi.x; w[5] = p.hi.y; w[6] = p.hi.z; w[7] = p.hi.w; - return w; -} - -fn pf_from_words(w0: u32, w1: u32, w2: u32, w3: u32, - w4: u32, w5: u32, w6: u32, w7: u32) -> PackedField { - var p: PackedField; - p.lo = vec4(w0, w1, w2, w3); - p.hi = vec4(w4, w5, w6, w7); - return p; -} - -fn unpack_field(p: PackedField) -> BigInt { - let w = pf_to_words(p); - return unpack256_to_limbs(w); -} - -fn pack_field(b: ptr) -> PackedField { - let w = pack_limbs_to_256(b); - return pf_from_words(w[0], w[1], w[2], w[3], w[4], w[5], w[6], w[7]); -} - -fn field_load_ro(idx: u32, src: ptr>, read>) -> PackedField { - var p: PackedField; - p.lo = (*src)[2u * idx]; - p.hi = (*src)[2u * idx + 1u]; - return p; -} - -fn field_load_rw(idx: u32, src: ptr>, read_write>) -> PackedField { - var p: PackedField; - p.lo = (*src)[2u * idx]; - p.hi = (*src)[2u * idx + 1u]; - return p; -} - -fn field_store(idx: u32, dst: ptr>, read_write>, val: PackedField) { - (*dst)[2u * idx] = val.lo; - (*dst)[2u * idx + 1u] = val.hi; -} - -fn is_zero_packed(a: PackedField) -> bool { - return all(a.lo == vec4(0u, 0u, 0u, 0u)) - && all(a.hi == vec4(0u, 0u, 0u, 0u)); -} - -fn eq_packed(a: PackedField, b: PackedField) -> bool { - return all(a.lo == b.lo) && all(a.hi == b.hi); -} - -fn get_zero_packed() -> PackedField { - return PackedField(vec4(0u), vec4(0u)); -} - fn get_r() -> BigInt { var r: BigInt; {{{ r_limbs }}} return r; } -fn get_p_packed() -> PackedField { - var p: BigInt = get_p(); - return pack_field(&p); -} - -fn get_r_packed() -> PackedField { - var r: BigInt = get_r(); - return pack_field(&r); -} - -fn mont_p(a: PackedField, b: PackedField) -> PackedField { - var a_l = unpack_field(a); - var b_l = unpack_field(b); - var out = montgomery_product(&a_l, &b_l); - return pack_field(&out); -} - -fn fr_add_p(a: PackedField, b: PackedField) -> PackedField { - var a_l = unpack_field(a); - var b_l = unpack_field(b); - var out = fr_add(&a_l, &b_l); - return pack_field(&out); -} - -fn fr_sub_p(a: PackedField, b: PackedField) -> PackedField { - var a_l = unpack_field(a); - var b_l = unpack_field(b); - var out = fr_sub(&a_l, &b_l); - return pack_field(&out); +fn field_load_ro(idx: u32, src: ptr>, read>) -> BigInt { + var w: array; + let q0 = (*src)[2u * idx]; + let q1 = (*src)[2u * idx + 1u]; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); } -fn fr_neg_p(a: PackedField) -> PackedField { - var a_l = unpack_field(a); - var p_l: BigInt = get_p(); - var out: BigInt; - let _b = bigint_sub(&p_l, &a_l, &out); - return pack_field(&out); +fn field_load_rw(idx: u32, src: ptr>, read_write>) -> BigInt { + var w: array; + let q0 = (*src)[2u * idx]; + let q1 = (*src)[2u * idx + 1u]; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); } -fn fr_inv_p(a: PackedField) -> PackedField { - let a_l = unpack_field(a); - var out = fr_inv_by_a(a_l); - return pack_field(&out); +fn field_store(idx: u32, dst: ptr>, read_write>, val: ptr) { + let w = pack_limbs_to_256(val); + (*dst)[2u * idx] = vec4(w[0], w[1], w[2], w[3]); + (*dst)[2u * idx + 1u] = vec4(w[4], w[5], w[6], w[7]); } `; diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_fused_wg_scan.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_fused_wg_scan.template.wgsl index a3230fc3c2b3..aadbd8c79261 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_fused_wg_scan.template.wgsl +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_fused_wg_scan.template.wgsl @@ -14,31 +14,21 @@ // Workgroup-scan fused batch-affine round kernel for v2 MSM. // -// Each workgroup of TPB threads cooperates on BATCH_SIZE = TPB * BS pairs -// from one subtask's pair pool, performs a workgroup-level Hillis-Steele -// prefix product over per-thread chunks, runs ONE fr_inv_by_a per -// workgroup, then back-walks per-thread emitting lean affine adds. This -// is the design validated in `bench_batch_affine.template.wgsl` (22 -// ns/pair at TPB=64, BS=16 on M2) with bucket-indirect loads/stores via -// `pair_target_meta`. -// -// LAYOUT -// - All field-element variables (workgroup, function, struct fields) -// are `PackedField` (two vec4). The 20×13-bit BigInt limb form -// only exists as a transient local inside mont_p / fr_*_p / fr_inv_p. -// - Per-subtask pair pool of length n (= count_buf[subtask_idx]) is -// dispatched as ceil(n / BATCH_SIZE) workgroups in X, num_subtasks -// in Z. The last workgroup of each subtask may have a partial batch -// (n - batch_base < BATCH_SIZE); threads with chunk_start >= -// batch_len contribute identity (R in Mont form) to the scan and -// skip phase D. +// Mirrors `bench_batch_affine.template.wgsl`'s phases A/B/C/D — TPB +// threads cooperating on BATCH_SIZE = TPB*BS pairs per workgroup with +// one fr_inv_by_a per workgroup — adapted for the MSM pipeline: +// - storage is packed 8×u32 per field element (vs the bench's +// BigInt-array storage); conversions happen only at field_load_* +// and field_store, every kernel-local var holds BigInt limbs +// - loads are bucket-indirect via `pair_target_meta` (vs the bench's +// flat `inputs[pair_base + *]`) // // PHASES -// A) Per-thread serial chunk: walk BS pairs, compute dx = Q.x - P.x -// and the inclusive prefix product. Captures block_total in a -// register, writes the per-element prefix into prefix_buf. -// B) Workgroup Hillis-Steele forward + backward scan over the TPB -// block_totals (log2 TPB rounds of mont mul). +// A) Per-thread serial prefix product over BS pairs. Each thread +// writes its prefix-product chain to `prefix[batch_base + k]` +// (global storage) and captures `block_total` in a register. +// B) Workgroup-shared Hillis-Steele forward + backward scan over the +// TPB block_totals (log2 TPB rounds of mont mul). // C) Thread 0 inverts the global product via fr_inv_by_a (ONE per // workgroup). Broadcasts to wg_inv_total. // D) Each thread back-walks its chunk, recovers inv_dx for each pair @@ -47,12 +37,20 @@ // running_x/y[bucket]. // // SAFETY -// The scheduler emits at most one pair per (subtask, bucket) per round -// (see batch_affine_schedule). So within a workgroup's BATCH_SIZE -// slots, every `bucket` is distinct → no intra-workgroup RAW hazards -// on the running_x/y scatters. Across workgroups in the same subtask: -// disjoint slot ranges → still distinct buckets. Across subtasks -// (Z dim): different bucket ranges entirely. +// The scheduler emits at most one pair per (subtask, bucket) per +// round. Within a workgroup's BATCH_SIZE slots, every `bucket` is +// distinct → no intra-workgroup RAW hazard on the running_x/y +// scatters. Across workgroups in the same subtask: disjoint slot +// ranges → still distinct buckets. Across subtasks (Z dim): different +// bucket ranges entirely. +// +// DISPATCH +// workgroup_size = TPB. Workgroups in X = ceil(n / (TPB*BS)). +// Workgroups in Z = num_subtasks. The atomicLoad of count_buf and +// subsequent control flow are uniform within a workgroup (every +// thread sees the same `n`), but Tint can't prove that — so we never +// early-return based on it. Instead, partial-batch threads contribute +// identity to the scan and skip their work loop bodies. const TPB: u32 = {{ tpb }}u; const BS: u32 = {{ bs }}u; @@ -71,7 +69,7 @@ var running_y: array>; @group(0) @binding(5) var pair_target_meta: array; @group(0) @binding(6) -var prefix_buf: array>; +var prefix_buf: array; @group(0) @binding(7) var count_buf: array>; @@ -80,9 +78,9 @@ var count_buf: array>; @group(0) @binding(8) var params: vec4; -var wg_fwd: array; -var wg_bwd: array; -var wg_inv_total: PackedField; +var wg_fwd: array; +var wg_bwd: array; +var wg_inv_total: BigInt; @compute @workgroup_size({{ tpb }}) @@ -102,46 +100,38 @@ fn main( let pool_base = subtask_idx * num_columns; let vi_offset = subtask_idx * input_size; - var batch_len: u32 = 0u; - if (batch_base < n) { - batch_len = min(BATCH_SIZE, n - batch_base); - } - let chunk_start = tid * BS; - var chunk_len: u32 = 0u; - if (chunk_start < batch_len) { - let chunk_end = min(chunk_start + BS, batch_len); - chunk_len = chunk_end - chunk_start; - } - - // Phase A — per-thread serial prefix product. Threads with - // chunk_len == 0 contribute identity (R = Mont 1) so the workgroup - // scan reads a sane value for every slot. - var block_total: PackedField = get_r_packed(); - if (chunk_len > 0u) { - let k0 = chunk_start; - let slot0 = pool_base + batch_base + k0; - let bucket0 = pair_target_meta[2u * slot0]; - let cursor0 = pair_target_meta[2u * slot0 + 1u]; - let pt_idx0 = val_idx[vi_offset + cursor0]; - let p_x0 = field_load_rw(bucket0, &running_x); - let q_x0 = field_load_ro(pt_idx0, &new_point_x); - let dx0 = fr_sub_p(q_x0, p_x0); - field_store(pool_base + batch_base + k0, &prefix_buf, dx0); - block_total = dx0; - + let chunk_pool_base = pool_base + batch_base + chunk_start; + + let in_pool = batch_base + chunk_start + BS <= n; + + // Phase A — per-thread serial prefix product. Inin_pool threads + // (chunk past the live pool) contribute identity (R = Mont 1) so + // the workgroup scan reads a sane value at every slot. + var block_total: BigInt = get_r(); + if (in_pool) { + { + let k0 = 0u; + let slot = chunk_pool_base + k0; + let bucket = pair_target_meta[2u * slot]; + let q_cursor = pair_target_meta[2u * slot + 1u]; + let pt_idx = val_idx[vi_offset + q_cursor]; + var p_x: BigInt = field_load_rw(bucket, &running_x); + var q_x: BigInt = field_load_ro(pt_idx, &new_point_x); + var dx: BigInt = fr_sub(&q_x, &p_x); + prefix_buf[chunk_pool_base + k0] = dx; + block_total = dx; + } for (var i: u32 = 1u; i < BS; i = i + 1u) { - if (i >= chunk_len) { break; } - let k = chunk_start + i; - let slot = pool_base + batch_base + k; + let slot = chunk_pool_base + i; let bucket = pair_target_meta[2u * slot]; - let cursor = pair_target_meta[2u * slot + 1u]; - let pt_idx = val_idx[vi_offset + cursor]; - let p_x = field_load_rw(bucket, &running_x); - let q_x = field_load_ro(pt_idx, &new_point_x); - let dx = fr_sub_p(q_x, p_x); - block_total = mont_p(block_total, dx); - field_store(pool_base + batch_base + k, &prefix_buf, block_total); + let q_cursor = pair_target_meta[2u * slot + 1u]; + let pt_idx = val_idx[vi_offset + q_cursor]; + var p_x: BigInt = field_load_rw(bucket, &running_x); + var q_x: BigInt = field_load_ro(pt_idx, &new_point_x); + var dx: BigInt = fr_sub(&q_x, &p_x); + block_total = montgomery_product(&block_total, &dx); + prefix_buf[chunk_pool_base + i] = block_total; } } @@ -151,15 +141,15 @@ fn main( // Phase B — Hillis-Steele forward + backward inclusive scan. for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) { - var fwd_x: PackedField = wg_fwd[tid]; + var fwd_x: BigInt = wg_fwd[tid]; if (tid >= stride) { - let lhs = wg_fwd[tid - stride]; - fwd_x = mont_p(lhs, fwd_x); + var lhs: BigInt = wg_fwd[tid - stride]; + fwd_x = montgomery_product(&lhs, &fwd_x); } - var bwd_x: PackedField = wg_bwd[tid]; + var bwd_x: BigInt = wg_bwd[tid]; if (tid + stride < TPB) { - let rhs = wg_bwd[tid + stride]; - bwd_x = mont_p(bwd_x, rhs); + var rhs: BigInt = wg_bwd[tid + stride]; + bwd_x = montgomery_product(&bwd_x, &rhs); } workgroupBarrier(); wg_fwd[tid] = fwd_x; @@ -167,67 +157,64 @@ fn main( workgroupBarrier(); } - // Phase C — single fr_inv per workgroup. wg_fwd[TPB-1] holds the - // product of every active (and identity-padding) block_total in the - // workgroup. + // Phase C — single fr_inv per workgroup. if (tid == 0u) { - let global_total = wg_fwd[TPB - 1u]; - wg_inv_total = fr_inv_p(global_total); + var global_total: BigInt = wg_fwd[TPB - 1u]; + wg_inv_total = fr_inv_by_a(global_total); } workgroupBarrier(); // Phase D — back-walk this thread's chunk, emit lean affine adds. - // Threads with chunk_len == 0 (overshoot dispatch or end-of-pool - // padding) skip the work loop entirely but stay live through any - // future workgroup-uniform code (currently none — D is the last - // phase). - var block_excl_prefix: PackedField = get_r_packed(); + if (!in_pool) { + return; + } + var block_excl_prefix: BigInt = get_r(); if (tid > 0u) { block_excl_prefix = wg_fwd[tid - 1u]; } - var block_excl_suffix: PackedField = get_r_packed(); + var block_excl_suffix: BigInt = get_r(); if (tid + 1u < TPB) { block_excl_suffix = wg_bwd[tid + 1u]; } - var inv_acc: PackedField = mont_p(wg_inv_total, block_excl_prefix); - inv_acc = mont_p(inv_acc, block_excl_suffix); + var inv_global: BigInt = wg_inv_total; + var inv_acc: BigInt = montgomery_product(&inv_global, &block_excl_prefix); + inv_acc = montgomery_product(&inv_acc, &block_excl_suffix); for (var off: u32 = 0u; off < BS; off = off + 1u) { - if (off >= chunk_len) { break; } - let k = chunk_start + (chunk_len - 1u - off); - let slot = pool_base + batch_base + k; + let k = BS - 1u - off; + let slot = chunk_pool_base + k; let bucket = pair_target_meta[2u * slot]; - let cursor = pair_target_meta[2u * slot + 1u]; - let pt_idx = val_idx[vi_offset + cursor]; - - let p_x = field_load_rw(bucket, &running_x); - let p_y = field_load_rw(bucket, &running_y); - let q_x = field_load_ro(pt_idx, &new_point_x); - let q_y = field_load_ro(pt_idx, &new_point_y); - - var inv_dx: PackedField; - if (k > chunk_start) { - let prev = field_load_rw(pool_base + batch_base + (k - 1u), &prefix_buf); - inv_dx = mont_p(inv_acc, prev); + let q_cursor = pair_target_meta[2u * slot + 1u]; + let pt_idx = val_idx[vi_offset + q_cursor]; + + var p_x: BigInt = field_load_rw(bucket, &running_x); + var p_y: BigInt = field_load_rw(bucket, &running_y); + var q_x: BigInt = field_load_ro(pt_idx, &new_point_x); + var q_y: BigInt = field_load_ro(pt_idx, &new_point_y); + + var inv_dx: BigInt; + if (k > 0u) { + var prev_prefix: BigInt = prefix_buf[chunk_pool_base + (k - 1u)]; + inv_dx = montgomery_product(&inv_acc, &prev_prefix); } else { inv_dx = inv_acc; } - let dy = fr_sub_p(q_y, p_y); - let lambda = mont_p(dy, inv_dx); - let lambda_sq = mont_p(lambda, lambda); - var r_x = fr_sub_p(lambda_sq, p_x); - r_x = fr_sub_p(r_x, q_x); - let dx_back = fr_sub_p(p_x, r_x); - let ldx = mont_p(lambda, dx_back); - let r_y = fr_sub_p(ldx, p_y); - - field_store(bucket, &running_x, r_x); - field_store(bucket, &running_y, r_y); - - if (k > chunk_start) { - let dx_fwd = fr_sub_p(q_x, p_x); - inv_acc = mont_p(inv_acc, dx_fwd); + var dy: BigInt = fr_sub(&q_y, &p_y); + var lambda: BigInt = montgomery_product(&dy, &inv_dx); + var lambda_sq: BigInt = montgomery_product(&lambda, &lambda); + var t1: BigInt = fr_sub(&lambda_sq, &p_x); + var r_x: BigInt = fr_sub(&t1, &q_x); + var dx_back: BigInt = fr_sub(&p_x, &r_x); + var ldx: BigInt = montgomery_product(&lambda, &dx_back); + var r_y: BigInt = fr_sub(&ldx, &p_y); + + field_store(bucket, &running_x, &r_x); + field_store(bucket, &running_y, &r_y); + + if (k > 0u) { + var dx_k: BigInt = fr_sub(&q_x, &p_x); + inv_acc = montgomery_product(&inv_acc, &dx_k); } } diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/struct/packed_field.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/struct/packed_field.template.wgsl index 8a3c481bcece..9dec37f1803b 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/struct/packed_field.template.wgsl +++ b/barretenberg/ts/src/msm_webgpu/wgsl/struct/packed_field.template.wgsl @@ -1,135 +1,48 @@ -// Packed 256-bit field-element type and primitive wrappers for v2 MSM. +// Packed 256-bit field-element storage helpers for v2 MSM. // -// A `PackedField` holds one canonical [0, q) BN254 base-field value as two -// vec4 (8 × u32 = 32 bytes little-endian). Storage buffers are -// `array>` with logical stride 2 vec4s per element. +// Storage convention: every field-element buffer is `array>` +// with logical stride 2 vec4s per element (8 × u32 = 32 bytes, +// canonical little-endian 256-bit value, value < q < 2^254). // -// Design constraint (from the v2 plan): every shader-level field-element -// variable, struct field, workgroup-shared var, and binding is -// `PackedField`. The 20×13-bit `BigInt` representation only appears as a -// transient local inside the wrappers below. No kernel ever calls -// `unpack256_to_limbs` or `pack_limbs_to_256` directly. -// -// Cost per primitive call: ~2 unpacks + 1 pack on top of the underlying -// BigInt operation. On Apple M2 each pack/unpack is ~10 cycles vs ~100 -// cycles for `montgomery_product`, so chains of mont-muls pay <15% -// overhead vs the BigInt calling convention used by the legacy -// msm_webgpu/ shaders. +// Conversions between the packed storage layout and the 20×13-bit +// `BigInt` arithmetic representation happen ONLY at the storage I/O +// boundary (field_load_*, field_store, fold_packed_pair). Once loaded, +// values live as BigInt limbs for the entire kernel body and only +// repack on the final write. This matches the bench_batch_affine design +// that hit ~22 ns/pair on M2; the prior PackedField-wrapper design +// repacked between every mont and paid ~2× the cost. // // PRECONDITION: this partial must be included after bigint_funcs, // montgomery_product_funcs, field_funcs, by_inverse_a_funcs, and after // the host has injected unpack256_to_limbs and pack_limbs_to_256 (those // come from the decoupledPackUnpackWgsl() generator in shader_manager). -struct PackedField { - lo: vec4, - hi: vec4, -} - -fn pf_to_words(p: PackedField) -> array { - var w: array; - w[0] = p.lo.x; w[1] = p.lo.y; w[2] = p.lo.z; w[3] = p.lo.w; - w[4] = p.hi.x; w[5] = p.hi.y; w[6] = p.hi.z; w[7] = p.hi.w; - return w; -} - -fn pf_from_words(w0: u32, w1: u32, w2: u32, w3: u32, - w4: u32, w5: u32, w6: u32, w7: u32) -> PackedField { - var p: PackedField; - p.lo = vec4(w0, w1, w2, w3); - p.hi = vec4(w4, w5, w6, w7); - return p; -} - -fn unpack_field(p: PackedField) -> BigInt { - let w = pf_to_words(p); - return unpack256_to_limbs(w); -} - -fn pack_field(b: ptr) -> PackedField { - let w = pack_limbs_to_256(b); - return pf_from_words(w[0], w[1], w[2], w[3], w[4], w[5], w[6], w[7]); -} - -fn field_load_ro(idx: u32, src: ptr>, read>) -> PackedField { - var p: PackedField; - p.lo = (*src)[2u * idx]; - p.hi = (*src)[2u * idx + 1u]; - return p; -} - -fn field_load_rw(idx: u32, src: ptr>, read_write>) -> PackedField { - var p: PackedField; - p.lo = (*src)[2u * idx]; - p.hi = (*src)[2u * idx + 1u]; - return p; -} - -fn field_store(idx: u32, dst: ptr>, read_write>, val: PackedField) { - (*dst)[2u * idx] = val.lo; - (*dst)[2u * idx + 1u] = val.hi; -} - -fn is_zero_packed(a: PackedField) -> bool { - return all(a.lo == vec4(0u, 0u, 0u, 0u)) - && all(a.hi == vec4(0u, 0u, 0u, 0u)); -} - -fn eq_packed(a: PackedField, b: PackedField) -> bool { - return all(a.lo == b.lo) && all(a.hi == b.hi); -} - -fn get_zero_packed() -> PackedField { - return PackedField(vec4(0u), vec4(0u)); -} - fn get_r() -> BigInt { var r: BigInt; {{{ r_limbs }}} return r; } -fn get_p_packed() -> PackedField { - var p: BigInt = get_p(); - return pack_field(&p); -} - -fn get_r_packed() -> PackedField { - var r: BigInt = get_r(); - return pack_field(&r); -} - -fn mont_p(a: PackedField, b: PackedField) -> PackedField { - var a_l = unpack_field(a); - var b_l = unpack_field(b); - var out = montgomery_product(&a_l, &b_l); - return pack_field(&out); -} - -fn fr_add_p(a: PackedField, b: PackedField) -> PackedField { - var a_l = unpack_field(a); - var b_l = unpack_field(b); - var out = fr_add(&a_l, &b_l); - return pack_field(&out); -} - -fn fr_sub_p(a: PackedField, b: PackedField) -> PackedField { - var a_l = unpack_field(a); - var b_l = unpack_field(b); - var out = fr_sub(&a_l, &b_l); - return pack_field(&out); +fn field_load_ro(idx: u32, src: ptr>, read>) -> BigInt { + var w: array; + let q0 = (*src)[2u * idx]; + let q1 = (*src)[2u * idx + 1u]; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); } -fn fr_neg_p(a: PackedField) -> PackedField { - var a_l = unpack_field(a); - var p_l: BigInt = get_p(); - var out: BigInt; - let _b = bigint_sub(&p_l, &a_l, &out); - return pack_field(&out); +fn field_load_rw(idx: u32, src: ptr>, read_write>) -> BigInt { + var w: array; + let q0 = (*src)[2u * idx]; + let q1 = (*src)[2u * idx + 1u]; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); } -fn fr_inv_p(a: PackedField) -> PackedField { - let a_l = unpack_field(a); - var out = fr_inv_by_a(a_l); - return pack_field(&out); +fn field_store(idx: u32, dst: ptr>, read_write>, val: ptr) { + let w = pack_limbs_to_256(val); + (*dst)[2u * idx] = vec4(w[0], w[1], w[2], w[3]); + (*dst)[2u * idx + 1u] = vec4(w[4], w[5], w[6], w[7]); } From e0315489fea58073595b4494f9e8769493c97e38 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Tue, 19 May 2026 14:11:54 +0000 Subject: [PATCH 05/33] fix(bb/msm/bench): hoist writeBuffer out of timed dispatch loop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per-rep writeBuffer of runningXBuf/runningYBuf was inside the timed dispatch() call, adding ~4 MB of CPU→GPU upload to every sample. At M2 staging-upload rates that's ~1 ms per writeBuffer × 2 buffers ≈ 2 ms per rep. With observed per-dispatch wall ≈ 3.5 ms, the writeBuffer was nearly half the measured time. Now: writeBuffer once at setup (so the initial state is correct for the correctness check on warmup dispatch), and only on the explicit resetState=true path. Timed reps run with resetState=false — the kernel keeps modifying running_x/y in-place, which doesn't affect the throughput measurement (same load/sub/mul/store work per pair regardless of input values). Expected effect: my measured 55-76 ns/pair drops by the writeBuffer share once re-measured on BS. --- .../ts/dev/msm-webgpu/bench-fused-wg-scan.ts | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts b/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts index 71b68c99284c..264bbc0a6f77 100644 --- a/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts +++ b/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts @@ -318,9 +318,15 @@ async function runOne( ], }); - const dispatch = async () => { - device.queue.writeBuffer(runningXBuf, 0, runningXAB); - device.queue.writeBuffer(runningYBuf, 0, runningYAB); + device.queue.writeBuffer(runningXBuf, 0, runningXAB); + device.queue.writeBuffer(runningYBuf, 0, runningYAB); + + const dispatch = async (resetState: boolean) => { + if (resetState) { + device.queue.writeBuffer(runningXBuf, 0, runningXAB); + device.queue.writeBuffer(runningYBuf, 0, runningYAB); + await device.queue.onSubmittedWorkDone(); + } const encoder = device.createCommandEncoder(); const pass = encoder.beginComputePass(); pass.setPipeline(pipeline); @@ -333,7 +339,7 @@ async function runOne( return performance.now() - t0; }; - await dispatch(); + await dispatch(true); log('info', 'warmup dispatch returned'); let correctness: 'pass' | 'fail' | 'skipped' = 'skipped'; @@ -408,7 +414,7 @@ async function runOne( const samples: number[] = []; for (let r = 0; r < reps; r++) { - samples.push(await dispatch()); + samples.push(await dispatch(false)); } const med = median(samples); const mn = Math.min(...samples); From c6a8fb0cce2866c0adcd258baaf5ad81dba736d1 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Tue, 19 May 2026 16:17:40 +0000 Subject: [PATCH 06/33] feat(bb/msm): standalone ba_rev_packed_carry bench (per-thread BS, packed I/O) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Recreate the lost ba_rev_packed_carry_bench WGSL kernel: per-thread descending suffix-product over BS pairs, single fr_inv_by_a per thread, ascending lean affine apply. Packed 8 x u32 storage at the I/O boundary only; every kernel-local var lives as 13-bit BigInt limbs. No bucket indirection, no scheduler inputs — flat addressing only — so this is the idealised standalone setting we use to calibrate the M2 22-24 ns/pair target the prior session reproduced (iter24 rev_packed_s16 = 24.22 ns/pair) and to quantify the gap from standalone to the integrated fused round kernel. Wiring: - new WGSL template at wgsl/cuzk/ba_rev_packed_carry_bench.template.wgsl - gen_ba_rev_packed_carry_bench_shader(tpb, bs) on ShaderManager - dev/msm-webgpu/bench-ba-rev-packed-carry.{ts,html} host harness with noble on-curve oracle (R = P + Q), sweep over BS at fixed TPB - run-browserstack.mjs pageMap entry "bench-ba-rev-packed-carry" --- .../msm-webgpu/bench-ba-rev-packed-carry.html | 37 ++ .../msm-webgpu/bench-ba-rev-packed-carry.ts | 519 ++++++++++++++++++ .../msm-webgpu/scripts/run-browserstack.mjs | 1 + .../ts/src/msm_webgpu/cuzk/shader_manager.ts | 47 ++ .../src/msm_webgpu/wgsl/_generated/shaders.ts | 121 +++- .../ba_rev_packed_carry_bench.template.wgsl | 117 ++++ 6 files changed, 841 insertions(+), 1 deletion(-) create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.html create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.ts create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_rev_packed_carry_bench.template.wgsl diff --git a/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.html b/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.html new file mode 100644 index 000000000000..3f67c1b1f553 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.html @@ -0,0 +1,37 @@ + + + + + ba_rev_packed_carry standalone batch-affine bench (WebGPU) + + + +

ba_rev_packed_carry standalone batch-affine bench (WebGPU)

+

Query params: ?reps=R&total=N&tpb=T&bs=S&skip_correctness=1

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.ts b/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.ts new file mode 100644 index 000000000000..aa0a3dae542d --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.ts @@ -0,0 +1,519 @@ +/// +// Standalone WebGPU bench + correctness oracle for the +// `ba_rev_packed_carry` batch-affine scheme: per-thread descending +// suffix-product, single fr_inv_by_a per thread, ascending lean +// apply, with packed 8x u32 storage at the I/O boundary and 13-bit +// BigInt limbs in every register-resident variable. +// +// Inputs: TOTAL_PAIRS on-curve BN254 G1 affine pairs (P_i, Q_i), +// stored flat. Thread `tid` consumes pairs[tid * BS .. (tid+1) * BS). +// The kernel writes R_i = P_i + Q_i to outputs_x[i] / outputs_y[i]; +// we decode packed Mont form back to canonical and compare to noble's +// reference P.add(Q). +// +// Sweep dimension: BS (per-thread batch size) at fixed TPB=64. Default +// sweep covers BS in {8, 12, 16, 20, 24} to bracket the M2 sweet spot. +// Override via ?bs=S or ?tpb=T. + +import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js'; +import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js'; +import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js'; +import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js'; +import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js'; +import { makeResultsClient } from './results_post.js'; +import { bn254 } from '@noble/curves/bn254'; + +const G1 = bn254.G1.ProjectivePoint; +const FR_ORDER = bn254.fields.Fr.ORDER; + +const DEFAULT_TOTAL_PAIRS = 1 << 16; +let TOTAL_PAIRS = DEFAULT_TOTAL_PAIRS; + +const DEFAULT_TPB = 64; +let TPB = DEFAULT_TPB; +const DEFAULT_BS_SWEEP: readonly number[] = [8, 12, 16, 20, 24]; +let BS_SWEEP: readonly number[] = DEFAULT_BS_SWEEP; + +let SKIP_CORRECTNESS = false; + +function makeRng(seed: number): () => number { + let state = (seed >>> 0) || 1; + return () => { + state = (Math.imul(state, 1664525) + 1013904223) >>> 0; + return state; + }; +} + +function randomBelow(p: bigint, rng: () => number): bigint { + const bitlen = p.toString(2).length; + const byteLen = Math.ceil(bitlen / 8); + for (;;) { + let v = 0n; + for (let i = 0; i < byteLen; i++) v = (v << 8n) | BigInt(rng() & 0xff); + v &= (1n << BigInt(bitlen)) - 1n; + if (v > 0n && v < p) return v; + } +} + +function median(xs: number[]): number { + if (xs.length === 0) return NaN; + const s = xs.slice().sort((a, b) => a - b); + return s[Math.floor(s.length / 2)]; +} + +function biToLe32u32(v: bigint): Uint32Array { + const out = new Uint32Array(8); + let x = v; + for (let i = 0; i < 8; i++) { + out[i] = Number(x & 0xffffffffn); + x >>= 32n; + } + return out; +} + +function le32u32ToBi(u32: Uint32Array, off: number): bigint { + let v = 0n; + for (let i = 7; i >= 0; i--) v = (v << 32n) | BigInt(u32[off + i] >>> 0); + return v; +} + +interface PerSizeResult { + bs: number; + tpb: number; + num_wgs: number; + total_pairs: number; + median_ms: number; + min_ms: number; + max_ms: number; + ns_per_pair: number; + samples_ms: number[]; + correctness: 'pass' | 'fail' | 'skipped'; + correctness_first_fail?: { i: number; expected_x: string; got_x: string; expected_y: string; got_y: string }; +} + +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: { reps: number; total: number; tpb: number; bs_sweep: readonly number[]; skip_correctness: boolean } | null; + results: PerSizeResult[]; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { + state: 'boot', + params: null, + results: [], + error: null, + log: [], +}; +(window as unknown as { __bench: BenchState }).__bench = benchState; + +const resultsClient = makeResultsClient({ page: 'bench-ba-rev-packed-carry' }); +(window as unknown as { __runId: string }).__runId = resultsClient.runId; + +async function postFinal(): Promise { + await resultsClient.postResults({ + state: benchState.state, + params: benchState.params, + results: benchState.results, + error: benchState.error, + log: benchState.log, + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); +} + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-ba-rev-packed-carry] ${msg}`); +} + +async function createPipeline( + device: GPUDevice, + code: string, + cacheKey: string, +): Promise<{ pipeline: GPUComputePipeline; layout: GPUBindGroupLayout }> { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + const errLines: string[] = []; + for (const msg of info.messages) { + const line = `[shader ${cacheKey}] ${msg.type}: ${msg.message} (line ${msg.lineNum}, col ${msg.linePos})`; + if (msg.type === 'error') { + console.error(line); + log('err', line); + errLines.push(line); + hasError = true; + } else { + console.warn(line); + log('warn', line); + } + } + if (hasError) { + throw new Error(`WGSL compile failed for ${cacheKey}: ${errLines.join(' | ')}`); + } + const layout = device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 5, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + ], + }); + const pipeline = await device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); + return { pipeline, layout }; +} + +interface PointPair { + p: { x: bigint; y: bigint }; + q: { x: bigint; y: bigint }; + r: { x: bigint; y: bigint }; +} + +function buildPairs(n: number, seed: number): PointPair[] { + const rng = makeRng(seed); + const out: PointPair[] = []; + for (let i = 0; i < n; i++) { + let p: { x: bigint; y: bigint }; + let q: { x: bigint; y: bigint }; + let r: { x: bigint; y: bigint }; + for (;;) { + const sp = randomBelow(FR_ORDER, rng); + const sq = randomBelow(FR_ORDER, rng); + const pp = G1.BASE.multiply(sp); + const qp = G1.BASE.multiply(sq); + if (pp.is0() || qp.is0()) continue; + const pa = pp.toAffine(); + const qa = qp.toAffine(); + if (pa.x === qa.x) continue; + const rp = pp.add(qp); + if (rp.is0()) continue; + const ra = rp.toAffine(); + p = pa; + q = qa; + r = ra; + break; + } + out.push({ p, q, r }); + } + return out; +} + +async function runOne( + device: GPUDevice, + sm: ShaderManager, + bs: number, + reps: number, + R: bigint, + p: bigint, + pairs: PointPair[], +): Promise { + const perThread = bs; + if (TOTAL_PAIRS % (TPB * perThread) !== 0) { + throw new Error( + `TOTAL_PAIRS=${TOTAL_PAIRS} must be a multiple of TPB*BS=${TPB * perThread}`, + ); + } + const totalThreads = TOTAL_PAIRS / perThread; + const numWgs = totalThreads / TPB; + log('info', `=== BS=${bs}: TPB=${TPB} num_threads=${totalThreads} num_WGs=${numWgs}`); + + const code = sm.gen_ba_rev_packed_carry_bench_shader(TPB, bs); + const cacheKey = `bench-ba-rev-packed-carry-T${TPB}-S${bs}`; + log('info', `compiling shader (${code.length} chars)`); + (window as unknown as Record)[`__shader_bs${bs}`] = code; + const { pipeline, layout } = await createPipeline(device, code, cacheKey); + + const fieldBytes = 32; // 8 x u32 = 2 vec4 + const bufBytes = TOTAL_PAIRS * fieldBytes; + + const pxAB = new ArrayBuffer(bufBytes); + const pyAB = new ArrayBuffer(bufBytes); + const qxAB = new ArrayBuffer(bufBytes); + const qyAB = new ArrayBuffer(bufBytes); + + const px32 = new Uint32Array(pxAB); + const py32 = new Uint32Array(pyAB); + const qx32 = new Uint32Array(qxAB); + const qy32 = new Uint32Array(qyAB); + + for (let i = 0; i < TOTAL_PAIRS; i++) { + const { p: pp, q: qq } = pairs[i]; + const pxM = (pp.x * R) % p; + const pyM = (pp.y * R) % p; + const qxM = (qq.x * R) % p; + const qyM = (qq.y * R) % p; + px32.set(biToLe32u32(pxM), i * 8); + py32.set(biToLe32u32(pyM), i * 8); + qx32.set(biToLe32u32(qxM), i * 8); + qy32.set(biToLe32u32(qyM), i * 8); + } + + const mkSb = (size: number, copyDst: boolean, copySrc: boolean): GPUBuffer => { + let usage = GPUBufferUsage.STORAGE; + if (copyDst) usage |= GPUBufferUsage.COPY_DST; + if (copySrc) usage |= GPUBufferUsage.COPY_SRC; + return device.createBuffer({ size, usage }); + }; + + const pxBuf = mkSb(bufBytes, true, false); + const pyBuf = mkSb(bufBytes, true, false); + const qxBuf = mkSb(bufBytes, true, false); + const qyBuf = mkSb(bufBytes, true, false); + const oxBuf = mkSb(bufBytes, false, true); + const oyBuf = mkSb(bufBytes, false, true); + + device.queue.writeBuffer(pxBuf, 0, pxAB); + device.queue.writeBuffer(pyBuf, 0, pyAB); + device.queue.writeBuffer(qxBuf, 0, qxAB); + device.queue.writeBuffer(qyBuf, 0, qyAB); + + const bindGroup = device.createBindGroup({ + layout, + entries: [ + { binding: 0, resource: { buffer: pxBuf } }, + { binding: 1, resource: { buffer: pyBuf } }, + { binding: 2, resource: { buffer: qxBuf } }, + { binding: 3, resource: { buffer: qyBuf } }, + { binding: 4, resource: { buffer: oxBuf } }, + { binding: 5, resource: { buffer: oyBuf } }, + ], + }); + + const dispatch = async (): Promise => { + const encoder = device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bindGroup); + pass.dispatchWorkgroups(numWgs, 1, 1); + pass.end(); + const t0 = performance.now(); + device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + return performance.now() - t0; + }; + + await dispatch(); + log('info', 'warmup dispatch returned'); + + let correctness: 'pass' | 'fail' | 'skipped' = 'skipped'; + let correctness_first_fail: PerSizeResult['correctness_first_fail']; + + if (!SKIP_CORRECTNESS) { + const stagingX = device.createBuffer({ + size: bufBytes, + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + }); + const stagingY = device.createBuffer({ + size: bufBytes, + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + }); + const enc = device.createCommandEncoder(); + enc.copyBufferToBuffer(oxBuf, 0, stagingX, 0, bufBytes); + enc.copyBufferToBuffer(oyBuf, 0, stagingY, 0, bufBytes); + device.queue.submit([enc.finish()]); + await Promise.all([stagingX.mapAsync(GPUMapMode.READ), stagingY.mapAsync(GPUMapMode.READ)]); + const gpuX = new Uint32Array(stagingX.getMappedRange().slice(0)); + const gpuY = new Uint32Array(stagingY.getMappedRange().slice(0)); + stagingX.unmap(); + stagingY.unmap(); + stagingX.destroy(); + stagingY.destroy(); + + const rInv = (() => { + let g = R % p; + let r = p; + let x = 0n; + let y = 1n; + while (g !== 0n) { + const q = r / g; + [r, g] = [g, r - q * g]; + [x, y] = [y, x - q * y]; + } + return ((x % p) + p) % p; + })(); + + let mismatches = 0; + const MAX_REPORTED = 4; + for (let i = 0; i < TOTAL_PAIRS; i++) { + const gxM = le32u32ToBi(gpuX, i * 8); + const gyM = le32u32ToBi(gpuY, i * 8); + const gx = (gxM * rInv) % p; + const gy = (gyM * rInv) % p; + const ex = pairs[i].r.x; + const ey = pairs[i].r.y; + if (gx !== ex || gy !== ey) { + mismatches++; + if (!correctness_first_fail) { + correctness_first_fail = { + i, + expected_x: ex.toString(), + got_x: gx.toString(), + expected_y: ey.toString(), + got_y: gy.toString(), + }; + } + if (mismatches > MAX_REPORTED) break; + } + } + correctness = mismatches === 0 ? 'pass' : 'fail'; + if (mismatches === 0) { + log('ok', `correctness: pass (${TOTAL_PAIRS}/${TOTAL_PAIRS} pairs match noble reference)`); + } else { + log('err', `correctness: FAIL (${mismatches}+ mismatches; first @ i=${correctness_first_fail!.i})`); + log('err', ` expected R.x = ${correctness_first_fail!.expected_x.slice(0, 24)}...`); + log('err', ` got R.x = ${correctness_first_fail!.got_x.slice(0, 24)}...`); + } + } + + const samples: number[] = []; + for (let r = 0; r < reps; r++) { + samples.push(await dispatch()); + } + const med = median(samples); + const mn = Math.min(...samples); + const mx = Math.max(...samples); + const nsPerPair = (med * 1e6) / TOTAL_PAIRS; + + log( + correctness === 'fail' ? 'err' : 'ok', + `BS=${bs}: median=${med.toFixed(3)}ms min=${mn.toFixed(3)}ms max=${mx.toFixed(3)}ms ns/pair=${nsPerPair.toFixed(1)} correctness=${correctness}`, + ); + + pxBuf.destroy(); + pyBuf.destroy(); + qxBuf.destroy(); + qyBuf.destroy(); + oxBuf.destroy(); + oyBuf.destroy(); + + return { + bs, + tpb: TPB, + num_wgs: numWgs, + total_pairs: TOTAL_PAIRS, + median_ms: med, + min_ms: mn, + max_ms: mx, + ns_per_pair: nsPerPair, + samples_ms: samples, + correctness, + correctness_first_fail, + }; +} + +function parseParams() { + const qp = new URLSearchParams(window.location.search); + const reps = parseInt(qp.get('reps') ?? '5', 10); + if (!Number.isFinite(reps) || reps <= 0 || reps > 50) { + throw new Error(`?reps must be in (0, 50], got ${qp.get('reps')}`); + } + const totalStr = qp.get('total'); + if (totalStr !== null) { + const total = parseInt(totalStr, 10); + if (!Number.isFinite(total) || total <= 0 || total > (1 << 20)) { + throw new Error(`?total must be in (0, 2^20], got ${totalStr}`); + } + TOTAL_PAIRS = total; + } + const tpbStr = qp.get('tpb'); + if (tpbStr !== null) { + const tpb = parseInt(tpbStr, 10); + if (!Number.isFinite(tpb) || tpb <= 0 || tpb > 1024) { + throw new Error(`?tpb must be in (0, 1024], got ${tpbStr}`); + } + TPB = tpb; + } + const bsStr = qp.get('bs'); + if (bsStr !== null) { + const list = bsStr.split(',').map(s => parseInt(s, 10)); + for (const s of list) { + if (!Number.isFinite(s) || s <= 0 || s > 64) { + throw new Error(`?bs entries must be in (0, 64], got ${s}`); + } + } + BS_SWEEP = list; + } + for (const s of BS_SWEEP) { + if (TOTAL_PAIRS % (TPB * s) !== 0) { + throw new Error(`BS=${s} with TPB=${TPB} does not divide TOTAL_PAIRS=${TOTAL_PAIRS}`); + } + } + if (qp.get('skip_correctness') === '1') { + SKIP_CORRECTNESS = true; + } + return { reps, total: TOTAL_PAIRS, tpb: TPB, bs_sweep: BS_SWEEP, skip_correctness: SKIP_CORRECTNESS }; +} + +async function main() { + try { + if (!('gpu' in navigator)) { + throw new Error('navigator.gpu missing — WebGPU not available'); + } + const params = parseParams(); + benchState.params = params; + log( + 'info', + `params: reps=${params.reps} total=${params.total} tpb=${params.tpb} bs=[${params.bs_sweep.join(',')}] skip_correctness=${params.skip_correctness}`, + ); + + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + + const p = BN254_BASE_FIELD; + const miscParams = compute_misc_params(p, 13); + const R = miscParams.r; + + log('info', `generating ${TOTAL_PAIRS} on-curve pairs via noble (this can take a few seconds)…`); + const t0 = performance.now(); + const pairs = buildPairs(TOTAL_PAIRS, 0xc0ffee); + log('info', `pair generation done in ${(performance.now() - t0).toFixed(0)} ms`); + + const sm = new ShaderManager(4, TOTAL_PAIRS, BN254_CURVE_CONFIG, false); + + for (const bs of BS_SWEEP) { + try { + const r = await runOne(device, sm, bs, params.reps, R, p, pairs); + benchState.results.push(r); + resultsClient.postProgress({ kind: 'batch_done', bs, median_ms: r.median_ms, ns_per_pair: r.ns_per_pair, correctness: r.correctness }); + } catch (e) { + const msg = e instanceof Error ? e.message : String(e); + log('err', `BS=${bs} failed: ${msg} — STOPPING sweep at first failure`); + benchState.state = 'error'; + benchState.error = msg; + return; + } + } + + benchState.state = 'done'; + log('ok', 'all sizes done'); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main() + .catch(e => { + const msg = e instanceof Error ? e.message : String(e); + log('err', `unhandled: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + }) + .finally(() => { + postFinal().catch(() => {}); + }); diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs index c503e9be2230..0f9c87ba5947 100644 --- a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs +++ b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs @@ -128,6 +128,7 @@ if (!TARGETS[argv.target]) { const pageMap = { "bench-batch-affine": "/dev/msm-webgpu/bench-batch-affine.html", "bench-fused-wg-scan": "/dev/msm-webgpu/bench-fused-wg-scan.html", + "bench-ba-rev-packed-carry": "/dev/msm-webgpu/bench-ba-rev-packed-carry.html", "bench-smvp-tree": "/dev/msm-webgpu/bench-smvp-tree.html", sanity: "/dev/msm-webgpu/index.html", }; diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts index 05948ff3f852..aa0f18ee121f 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts @@ -5,6 +5,7 @@ import { batch_affine_apply as batch_affine_apply_shader, batch_affine_apply_scatter as batch_affine_apply_scatter_shader, batch_affine_dispatch_args as batch_affine_dispatch_args_shader, + ba_rev_packed_carry_bench as ba_rev_packed_carry_bench_shader, bench_batch_affine as bench_batch_affine_shader, batch_affine_finalize as batch_affine_finalize_shader, batch_affine_finalize_apply as batch_affine_finalize_apply_shader, @@ -764,6 +765,52 @@ ${packLines.join('\n')} ); } + /** + * Standalone single-dispatch microbench for the ba_rev_packed_carry + * batch-affine scheme — packed 8x u32 storage, per-thread BS-pair + * descending suffix-product, single fr_inv_by_a per thread, ascending + * lean affine apply. No bucket indirection, no scheduler inputs. Used + * to validate that the M2 22-24 ns/pair number is achievable in the + * idealised standalone setting and to quantify the gap to integration. + */ + public gen_ba_rev_packed_carry_bench_shader(tpb: number, bs: number): string { + if (tpb <= 0 || bs <= 0 || !Number.isInteger(tpb) || !Number.isInteger(bs)) { + throw new Error(`gen_ba_rev_packed_carry_bench_shader: tpb (${tpb}) and bs (${bs}) must be positive integers`); + } + const dec = this.decoupledPackUnpackWgsl(); + return mustache.render( + ba_rev_packed_carry_bench_shader, + { + tpb, + bs, + word_size: this.word_size, + num_words: this.num_words, + n0: this.n0, + p_limbs: this.p_limbs, + r_limbs: this.r_limbs, + r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, + mask: this.mask, + two_pow_word_size: this.two_pow_word_size, + p_inv_mod_2w: this.p_inv_mod_2w, + p_inv_by_a_lo: this.p_inv_by_a_lo, + dec_unpack: dec.unpack, + dec_pack: dec.pack, + recompile: this.recompile, + }, + { + structs, + bigint_funcs, + montgomery_product_funcs: this.mont_product_src, + field_funcs, + fr_pow_funcs, + bigint_by_funcs, + by_inverse_a_funcs, + packed_field_funcs, + }, + ); + } + public gen_batch_affine_init_shader(workgroup_size: number, packed = false): string { const dec = this.decoupledPackUnpackWgsl(); return mustache.render( diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index 743b189213a0..ede4abcdb80b 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -1,6 +1,6 @@ // AUTO-GENERATED by scripts/inline-wgsl.mjs. DO NOT EDIT. // Run `yarn generate:wgsl` (or `node scripts/inline-wgsl.mjs`) to regenerate. -// 50 shader sources inlined. +// 51 shader sources inlined. /* eslint-disable */ @@ -1351,6 +1351,125 @@ fn main(@builtin(global_invocation_id) gid: vec3) { } `; +export const ba_rev_packed_carry_bench = `{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +{{> packed_field_funcs }} + +// Standalone single-dispatch microbench for the ba_rev_packed_carry +// batch-affine EC-add scheme: +// per-thread descending suffix-product + single fr_inv_by_a + +// ascending lean-apply, with packed 8x u32 storage at the I/O +// boundary and 13-bit BigInt limbs in every register-resident var. +// +// Each thread independently processes BS consecutive pairs from a flat +// pool (no bucket indirection, no scheduler input). Threads in a +// workgroup share no data; the only reason for TPB threads/workgroup +// is SIMD-lockstep execution of the BY safegcd inversion so its +// latency amortises across the wave. +// +// DISPATCH +// workgroups = TOTAL_PAIRS / (TPB * BS), threads/wg = TPB. +// Thread gid = wid.x * TPB + lid.x owns pairs [gid*BS, (gid+1)*BS). +// +// PHASES (entirely per-thread; no workgroup memory) +// A) Descending suffix-product over BS pairs. dx_k = Q.x_k - P.x_k. +// suf[k] = product_{j >= k} dx_j. acc threads through dx_{BS-1}, +// dx_{BS-2}, ..., dx_0. +// B) Single inv = fr_inv_by_a(suf[0]). +// C) Ascending lean apply: for k = 0..BS-1 +// inv_dx_k = (k+1 < BS) ? inv * suf[k+1] : inv +// lambda = (Q.y - P.y) * inv_dx_k +// R.x = lambda^2 - P.x - Q.x +// R.y = lambda * (P.x - R.x) - P.y +// if k+1 < BS: inv = inv * dx_k (forward-propagate) +// +// LOOP BOUNDS — every loop bound is a compile-time Mustache const +// (BS, NUM_WORDS). No data-dependent unbounded loops. + +const TPB: u32 = {{ tpb }}u; +const BS: u32 = {{ bs }}u; + +@group(0) @binding(0) var inputs_p_x: array>; +@group(0) @binding(1) var inputs_p_y: array>; +@group(0) @binding(2) var inputs_q_x: array>; +@group(0) @binding(3) var inputs_q_y: array>; +@group(0) @binding(4) var outputs_x: array>; +@group(0) @binding(5) var outputs_y: array>; + +@compute +@workgroup_size({{ tpb }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let tid = gid.x; + let base = tid * BS; + + // Phase A — descending suffix-product over BS pairs. + var suf: array; + var acc: BigInt; + for (var jj: u32 = 0u; jj < BS; jj = jj + 1u) { + let k = BS - 1u - jj; + let idx = base + k; + var p_x: BigInt = field_load_ro(idx, &inputs_p_x); + var q_x: BigInt = field_load_ro(idx, &inputs_q_x); + var dx: BigInt = fr_sub(&q_x, &p_x); + if (jj == 0u) { + acc = dx; + } else { + acc = montgomery_product(&acc, &dx); + } + suf[k] = acc; + } + + // Phase B — single inversion per thread. + var inv: BigInt = fr_inv_by_a(suf[0]); + + // Phase C — ascending lean apply. + for (var k: u32 = 0u; k < BS; k = k + 1u) { + let idx = base + k; + var p_x: BigInt = field_load_ro(idx, &inputs_p_x); + var p_y: BigInt = field_load_ro(idx, &inputs_p_y); + var q_x: BigInt = field_load_ro(idx, &inputs_q_x); + var q_y: BigInt = field_load_ro(idx, &inputs_q_y); + + var inv_dx: BigInt; + if (k + 1u < BS) { + var sp = suf[k + 1u]; + inv_dx = montgomery_product(&inv, &sp); + } else { + inv_dx = inv; + } + + var dy: BigInt = fr_sub(&q_y, &p_y); + var lambda: BigInt = montgomery_product(&dy, &inv_dx); + var r_x: BigInt = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &p_x); + r_x = fr_sub(&r_x, &q_x); + var dxb: BigInt = fr_sub(&p_x, &r_x); + var r_y: BigInt = montgomery_product(&lambda, &dxb); + r_y = fr_sub(&r_y, &p_y); + + field_store(idx, &outputs_x, &r_x); + field_store(idx, &outputs_y, &r_y); + + if (k + 1u < BS) { + var dxf: BigInt = fr_sub(&q_x, &p_x); + inv = montgomery_product(&inv, &dxf); + } + } + + {{{ recompile }}} +} +`; + export const barrett = `const W_MASK = {{ w_mask }}u; const SLACK = {{ slack }}u; diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_rev_packed_carry_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_rev_packed_carry_bench.template.wgsl new file mode 100644 index 000000000000..51b795840159 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_rev_packed_carry_bench.template.wgsl @@ -0,0 +1,117 @@ +{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +{{> packed_field_funcs }} + +// Standalone single-dispatch microbench for the ba_rev_packed_carry +// batch-affine EC-add scheme: +// per-thread descending suffix-product + single fr_inv_by_a + +// ascending lean-apply, with packed 8x u32 storage at the I/O +// boundary and 13-bit BigInt limbs in every register-resident var. +// +// Each thread independently processes BS consecutive pairs from a flat +// pool (no bucket indirection, no scheduler input). Threads in a +// workgroup share no data; the only reason for TPB threads/workgroup +// is SIMD-lockstep execution of the BY safegcd inversion so its +// latency amortises across the wave. +// +// DISPATCH +// workgroups = TOTAL_PAIRS / (TPB * BS), threads/wg = TPB. +// Thread gid = wid.x * TPB + lid.x owns pairs [gid*BS, (gid+1)*BS). +// +// PHASES (entirely per-thread; no workgroup memory) +// A) Descending suffix-product over BS pairs. dx_k = Q.x_k - P.x_k. +// suf[k] = product_{j >= k} dx_j. acc threads through dx_{BS-1}, +// dx_{BS-2}, ..., dx_0. +// B) Single inv = fr_inv_by_a(suf[0]). +// C) Ascending lean apply: for k = 0..BS-1 +// inv_dx_k = (k+1 < BS) ? inv * suf[k+1] : inv +// lambda = (Q.y - P.y) * inv_dx_k +// R.x = lambda^2 - P.x - Q.x +// R.y = lambda * (P.x - R.x) - P.y +// if k+1 < BS: inv = inv * dx_k (forward-propagate) +// +// LOOP BOUNDS — every loop bound is a compile-time Mustache const +// (BS, NUM_WORDS). No data-dependent unbounded loops. + +const TPB: u32 = {{ tpb }}u; +const BS: u32 = {{ bs }}u; + +@group(0) @binding(0) var inputs_p_x: array>; +@group(0) @binding(1) var inputs_p_y: array>; +@group(0) @binding(2) var inputs_q_x: array>; +@group(0) @binding(3) var inputs_q_y: array>; +@group(0) @binding(4) var outputs_x: array>; +@group(0) @binding(5) var outputs_y: array>; + +@compute +@workgroup_size({{ tpb }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let tid = gid.x; + let base = tid * BS; + + // Phase A — descending suffix-product over BS pairs. + var suf: array; + var acc: BigInt; + for (var jj: u32 = 0u; jj < BS; jj = jj + 1u) { + let k = BS - 1u - jj; + let idx = base + k; + var p_x: BigInt = field_load_ro(idx, &inputs_p_x); + var q_x: BigInt = field_load_ro(idx, &inputs_q_x); + var dx: BigInt = fr_sub(&q_x, &p_x); + if (jj == 0u) { + acc = dx; + } else { + acc = montgomery_product(&acc, &dx); + } + suf[k] = acc; + } + + // Phase B — single inversion per thread. + var inv: BigInt = fr_inv_by_a(suf[0]); + + // Phase C — ascending lean apply. + for (var k: u32 = 0u; k < BS; k = k + 1u) { + let idx = base + k; + var p_x: BigInt = field_load_ro(idx, &inputs_p_x); + var p_y: BigInt = field_load_ro(idx, &inputs_p_y); + var q_x: BigInt = field_load_ro(idx, &inputs_q_x); + var q_y: BigInt = field_load_ro(idx, &inputs_q_y); + + var inv_dx: BigInt; + if (k + 1u < BS) { + var sp = suf[k + 1u]; + inv_dx = montgomery_product(&inv, &sp); + } else { + inv_dx = inv; + } + + var dy: BigInt = fr_sub(&q_y, &p_y); + var lambda: BigInt = montgomery_product(&dy, &inv_dx); + var r_x: BigInt = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &p_x); + r_x = fr_sub(&r_x, &q_x); + var dxb: BigInt = fr_sub(&p_x, &r_x); + var r_y: BigInt = montgomery_product(&lambda, &dxb); + r_y = fr_sub(&r_y, &p_y); + + field_store(idx, &outputs_x, &r_x); + field_store(idx, &outputs_y, &r_y); + + if (k + 1u < BS) { + var dxf: BigInt = fr_sub(&q_x, &p_x); + inv = montgomery_product(&inv, &dxf); + } + } + + {{{ recompile }}} +} From 151e4e4ef7277e047a5cf8c573984c373a8df20d Mon Sep 17 00:00:00 2001 From: AztecBot Date: Tue, 19 May 2026 17:01:16 +0000 Subject: [PATCH 07/33] feat(bb/msm): port recovered ba_rev_packed_carry kernel from eab3a3e MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the first-principles reconstruction with the canonical kernel recovered by the remote agent in commit eab3a3e. Differences that matter for the ~22 ns/pair M2 result: - bucket-accumulate streaming chain (A_{i+1} := P_i resident load-carry) instead of independent P+Q pairs — forward prefix-product, ONE fr_inv_by_a per S-chunk, backward peel with dx recomputed free - SoA-packed layout: one input buffer, 4 planes (A.x, A.y, P.x, P.y), each PG=2 vec4/elem; strided per-thread access e = t + i*T for full coalescing; separate 2-plane output buffer; params uniform = (N, T) - host harness measures DISP=8 back-to-back dispatches per timed sample (amortises submit+drain — the dominant source of the 29 vs 22 ns gap in the prior single-dispatch reconstruction), PAIRS=131072, sweep S in {16,32,64}, sanity = readNonZero on R.x plane shader_manager.gen_ba_rev_packed_carry_bench_shader now takes (workgroup_size, s) to match the recovered template's mustache vars. --- .../msm-webgpu/bench-ba-rev-packed-carry.html | 2 +- .../msm-webgpu/bench-ba-rev-packed-carry.ts | 497 ++++++++---------- .../ts/src/msm_webgpu/cuzk/shader_manager.ts | 23 +- .../src/msm_webgpu/wgsl/_generated/shaders.ts | 226 +++++--- .../ba_rev_packed_carry_bench.template.wgsl | 220 +++++--- 5 files changed, 497 insertions(+), 471 deletions(-) diff --git a/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.html b/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.html index 3f67c1b1f553..6676d4772bd2 100644 --- a/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.html +++ b/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.html @@ -30,7 +30,7 @@

ba_rev_packed_carry standalone batch-affine bench (WebGPU)

-

Query params: ?reps=R&total=N&tpb=T&bs=S&skip_correctness=1

+

Query params: ?reps=R&pairs=N&wgi=W&disp=D&s=A,B,C

diff --git a/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.ts b/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.ts index aa0a3dae542d..2e55c2c9b7e4 100644 --- a/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.ts +++ b/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.ts @@ -1,19 +1,15 @@ /// -// Standalone WebGPU bench + correctness oracle for the -// `ba_rev_packed_carry` batch-affine scheme: per-thread descending -// suffix-product, single fr_inv_by_a per thread, ascending lean -// apply, with packed 8x u32 storage at the I/O boundary and 13-bit -// BigInt limbs in every register-resident variable. +// Standalone WebGPU bench for the canonical ba_rev_packed_carry kernel +// (recovered from commit eab3a3e). SoA-packed 8x u32 storage across 4 +// input planes (A.x, A.y, P.x, P.y), strided per-thread access +// e = t + i*T, single fr_inv_by_a per S-chunk, lean affine apply with +// resident-accumulator load-carry. Each timed sample submits DISP back- +// to-back dispatches in one command encoder to amortise submit+drain. // -// Inputs: TOTAL_PAIRS on-curve BN254 G1 affine pairs (P_i, Q_i), -// stored flat. Thread `tid` consumes pairs[tid * BS .. (tid+1) * BS). -// The kernel writes R_i = P_i + Q_i to outputs_x[i] / outputs_y[i]; -// we decode packed Mont form back to canonical and compare to noble's -// reference P.add(Q). -// -// Sweep dimension: BS (per-thread batch size) at fixed TPB=64. Default -// sweep covers BS in {8, 12, 16, 20, 24} to bracket the M2 sweet spot. -// Override via ?bs=S or ?tpb=T. +// Math: bucket-accumulate streaming chain (R_i = A_i + P_i where A_0 is +// the seed, A_{i+1} := P_i). With random P.x != A.x inputs, every dx is +// nonzero so the batched inverse is well-defined. Sanity check reads +// the first packed elem of R.x and asserts non-zero. import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js'; import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js'; @@ -21,20 +17,17 @@ import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js'; import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js'; import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js'; import { makeResultsClient } from './results_post.js'; -import { bn254 } from '@noble/curves/bn254'; - -const G1 = bn254.G1.ProjectivePoint; -const FR_ORDER = bn254.fields.Fr.ORDER; -const DEFAULT_TOTAL_PAIRS = 1 << 16; -let TOTAL_PAIRS = DEFAULT_TOTAL_PAIRS; +const DEFAULT_PAIRS = 1 << 17; // 131072 +const DEFAULT_WGI = 64; +const DEFAULT_DISP = 8; +const DEFAULT_S_SWEEP: readonly number[] = [16, 32, 64]; +const PG = 2; // 8 packed u32 / 4 = 2 vec4 groups per element -const DEFAULT_TPB = 64; -let TPB = DEFAULT_TPB; -const DEFAULT_BS_SWEEP: readonly number[] = [8, 12, 16, 20, 24]; -let BS_SWEEP: readonly number[] = DEFAULT_BS_SWEEP; - -let SKIP_CORRECTNESS = false; +let PAIRS = DEFAULT_PAIRS; +let WGI = DEFAULT_WGI; +let DISP = DEFAULT_DISP; +let S_SWEEP: readonly number[] = DEFAULT_S_SWEEP; function makeRng(seed: number): () => number { let state = (seed >>> 0) || 1; @@ -55,45 +48,69 @@ function randomBelow(p: bigint, rng: () => number): bigint { } } -function median(xs: number[]): number { - if (xs.length === 0) return NaN; - const s = xs.slice().sort((a, b) => a - b); - return s[Math.floor(s.length / 2)]; -} - -function biToLe32u32(v: bigint): Uint32Array { - const out = new Uint32Array(8); +function bigintToPackedU32x8(v: bigint): Uint32Array { + const w = new Uint32Array(8); let x = v; for (let i = 0; i < 8; i++) { - out[i] = Number(x & 0xffffffffn); + w[i] = Number(x & 0xffffffffn); x >>= 32n; } - return out; + return w; +} + +// SoA layout: 4 planes (A.x, A.y, P.x, P.y), each plane = PG * N vec4. +// For plane c and pair e, vec4 indices = (c * PG + v) * N + e for v in +// [0, PG). The flat Uint32Array index is that times 4. +function packAffineSoAPacked(pairs: number, R: bigint, p: bigint, rng: () => number): Uint32Array { + const N = pairs; + const buf = new Uint32Array(4 * PG * N * 4); + for (let e = 0; e < N; e++) { + let pxM: bigint; + let qxM: bigint; + do { + pxM = (randomBelow(p, rng) * R) % p; + qxM = (randomBelow(p, rng) * R) % p; + } while (pxM === qxM); + const coords = [pxM, (randomBelow(p, rng) * R) % p, qxM, (randomBelow(p, rng) * R) % p]; + for (let c = 0; c < 4; c++) { + const words = bigintToPackedU32x8(coords[c]); + for (let v = 0; v < PG; v++) { + const base = ((c * PG + v) * N + e) * 4; + buf[base + 0] = words[4 * v + 0]; + buf[base + 1] = words[4 * v + 1]; + buf[base + 2] = words[4 * v + 2]; + buf[base + 3] = words[4 * v + 3]; + } + } + } + return buf; } -function le32u32ToBi(u32: Uint32Array, off: number): bigint { - let v = 0n; - for (let i = 7; i >= 0; i--) v = (v << 32n) | BigInt(u32[off + i] >>> 0); - return v; +function median(xs: number[]): number { + if (xs.length === 0) return NaN; + const s = xs.slice().sort((a, b) => a - b); + return s[Math.floor(s.length / 2)]; } interface PerSizeResult { - bs: number; - tpb: number; + s: number; + wgi: number; + T: number; num_wgs: number; - total_pairs: number; + pairs: number; + disp: number; + total_ops: number; median_ms: number; min_ms: number; max_ms: number; - ns_per_pair: number; + ns_per_op: number; samples_ms: number[]; - correctness: 'pass' | 'fail' | 'skipped'; - correctness_first_fail?: { i: number; expected_x: string; got_x: string; expected_y: string; got_y: string }; + sanity_ok: boolean; } interface BenchState { state: 'boot' | 'running' | 'done' | 'error'; - params: { reps: number; total: number; tpb: number; bs_sweep: readonly number[]; skip_correctness: boolean } | null; + params: { reps: number; pairs: number; wgi: number; disp: number; s_sweep: readonly number[] } | null; results: PerSizeResult[]; error: string | null; log: string[]; @@ -162,10 +179,8 @@ async function createPipeline( entries: [ { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, - { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, - { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, - { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, - { binding: 5, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, ], }); const pipeline = await device.createComputePipelineAsync({ @@ -175,241 +190,143 @@ async function createPipeline( return { pipeline, layout }; } -interface PointPair { - p: { x: bigint; y: bigint }; - q: { x: bigint; y: bigint }; - r: { x: bigint; y: bigint }; +async function readNonZero(device: GPUDevice, out: GPUBuffer, u32Count: number): Promise { + const bytes = u32Count * 4; + const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }); + const enc = device.createCommandEncoder(); + enc.copyBufferToBuffer(out, 0, staging, 0, bytes); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + await staging.mapAsync(GPUMapMode.READ); + const u32 = new Uint32Array(staging.getMappedRange().slice(0)); + staging.unmap(); + staging.destroy(); + for (let i = 0; i < u32.length; i++) { + if (u32[i] !== 0) return true; + } + return false; } -function buildPairs(n: number, seed: number): PointPair[] { - const rng = makeRng(seed); - const out: PointPair[] = []; - for (let i = 0; i < n; i++) { - let p: { x: bigint; y: bigint }; - let q: { x: bigint; y: bigint }; - let r: { x: bigint; y: bigint }; - for (;;) { - const sp = randomBelow(FR_ORDER, rng); - const sq = randomBelow(FR_ORDER, rng); - const pp = G1.BASE.multiply(sp); - const qp = G1.BASE.multiply(sq); - if (pp.is0() || qp.is0()) continue; - const pa = pp.toAffine(); - const qa = qp.toAffine(); - if (pa.x === qa.x) continue; - const rp = pp.add(qp); - if (rp.is0()) continue; - const ra = rp.toAffine(); - p = pa; - q = qa; - r = ra; - break; +async function timeDispatch( + device: GPUDevice, + pipeline: GPUComputePipeline, + bind: GPUBindGroup, + numWgs: number, + reps: number, + passes: number, +): Promise { + // warmup + { + const enc = device.createCommandEncoder(); + for (let pIdx = 0; pIdx < passes; pIdx++) { + const pass = enc.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroups(numWgs, 1, 1); + pass.end(); + } + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + } + const samples: number[] = []; + for (let r = 0; r < reps; r++) { + const enc = device.createCommandEncoder(); + for (let pIdx = 0; pIdx < passes; pIdx++) { + const pass = enc.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroups(numWgs, 1, 1); + pass.end(); } - out.push({ p, q, r }); + const t0 = performance.now(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + samples.push(performance.now() - t0); } - return out; + return samples; } async function runOne( device: GPUDevice, sm: ShaderManager, - bs: number, + s: number, reps: number, R: bigint, p: bigint, - pairs: PointPair[], + seed: number, ): Promise { - const perThread = bs; - if (TOTAL_PAIRS % (TPB * perThread) !== 0) { - throw new Error( - `TOTAL_PAIRS=${TOTAL_PAIRS} must be a multiple of TPB*BS=${TPB * perThread}`, - ); + if (PAIRS % s !== 0) { + throw new Error(`PAIRS=${PAIRS} must be a multiple of S=${s}`); } - const totalThreads = TOTAL_PAIRS / perThread; - const numWgs = totalThreads / TPB; - log('info', `=== BS=${bs}: TPB=${TPB} num_threads=${totalThreads} num_WGs=${numWgs}`); + const T = PAIRS / s; + const numWgs = Math.ceil(T / WGI); + log('info', `=== S=${s}: PAIRS=${PAIRS} T=${T} WGI=${WGI} numWgs=${numWgs} DISP=${DISP}`); - const code = sm.gen_ba_rev_packed_carry_bench_shader(TPB, bs); - const cacheKey = `bench-ba-rev-packed-carry-T${TPB}-S${bs}`; + const code = sm.gen_ba_rev_packed_carry_bench_shader(WGI, s); + const cacheKey = `bench-ba-rev-packed-carry-W${WGI}-S${s}`; log('info', `compiling shader (${code.length} chars)`); - (window as unknown as Record)[`__shader_bs${bs}`] = code; + (window as unknown as Record)[`__shader_s${s}`] = code; const { pipeline, layout } = await createPipeline(device, code, cacheKey); - const fieldBytes = 32; // 8 x u32 = 2 vec4 - const bufBytes = TOTAL_PAIRS * fieldBytes; - - const pxAB = new ArrayBuffer(bufBytes); - const pyAB = new ArrayBuffer(bufBytes); - const qxAB = new ArrayBuffer(bufBytes); - const qyAB = new ArrayBuffer(bufBytes); - - const px32 = new Uint32Array(pxAB); - const py32 = new Uint32Array(pyAB); - const qx32 = new Uint32Array(qxAB); - const qy32 = new Uint32Array(qyAB); - - for (let i = 0; i < TOTAL_PAIRS; i++) { - const { p: pp, q: qq } = pairs[i]; - const pxM = (pp.x * R) % p; - const pyM = (pp.y * R) % p; - const qxM = (qq.x * R) % p; - const qyM = (qq.y * R) % p; - px32.set(biToLe32u32(pxM), i * 8); - py32.set(biToLe32u32(pyM), i * 8); - qx32.set(biToLe32u32(qxM), i * 8); - qy32.set(biToLe32u32(qyM), i * 8); - } - - const mkSb = (size: number, copyDst: boolean, copySrc: boolean): GPUBuffer => { - let usage = GPUBufferUsage.STORAGE; - if (copyDst) usage |= GPUBufferUsage.COPY_DST; - if (copySrc) usage |= GPUBufferUsage.COPY_SRC; - return device.createBuffer({ size, usage }); - }; - - const pxBuf = mkSb(bufBytes, true, false); - const pyBuf = mkSb(bufBytes, true, false); - const qxBuf = mkSb(bufBytes, true, false); - const qyBuf = mkSb(bufBytes, true, false); - const oxBuf = mkSb(bufBytes, false, true); - const oyBuf = mkSb(bufBytes, false, true); - - device.queue.writeBuffer(pxBuf, 0, pxAB); - device.queue.writeBuffer(pyBuf, 0, pyAB); - device.queue.writeBuffer(qxBuf, 0, qxAB); - device.queue.writeBuffer(qyBuf, 0, qyAB); + const rng = makeRng(seed); + const inU32 = packAffineSoAPacked(PAIRS, R, p, rng); + const inBuf = device.createBuffer({ + size: inU32.byteLength, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST, + }); + device.queue.writeBuffer(inBuf, 0, inU32); + const dummy = device.createBuffer({ size: 16, usage: GPUBufferUsage.STORAGE }); + const outBytes = 2 * PG * PAIRS * 4 * 4; + const outBuf = device.createBuffer({ + size: outBytes, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC, + }); + const paramsBuf = device.createBuffer({ + size: 16, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + device.queue.writeBuffer(paramsBuf, 0, new Uint32Array([PAIRS, T, 0, 0])); - const bindGroup = device.createBindGroup({ + const bind = device.createBindGroup({ layout, entries: [ - { binding: 0, resource: { buffer: pxBuf } }, - { binding: 1, resource: { buffer: pyBuf } }, - { binding: 2, resource: { buffer: qxBuf } }, - { binding: 3, resource: { buffer: qyBuf } }, - { binding: 4, resource: { buffer: oxBuf } }, - { binding: 5, resource: { buffer: oyBuf } }, + { binding: 0, resource: { buffer: inBuf } }, + { binding: 1, resource: { buffer: dummy } }, + { binding: 2, resource: { buffer: outBuf } }, + { binding: 3, resource: { buffer: paramsBuf } }, ], }); - const dispatch = async (): Promise => { - const encoder = device.createCommandEncoder(); - const pass = encoder.beginComputePass(); - pass.setPipeline(pipeline); - pass.setBindGroup(0, bindGroup); - pass.dispatchWorkgroups(numWgs, 1, 1); - pass.end(); - const t0 = performance.now(); - device.queue.submit([encoder.finish()]); - await device.queue.onSubmittedWorkDone(); - return performance.now() - t0; - }; - - await dispatch(); - log('info', 'warmup dispatch returned'); - - let correctness: 'pass' | 'fail' | 'skipped' = 'skipped'; - let correctness_first_fail: PerSizeResult['correctness_first_fail']; - - if (!SKIP_CORRECTNESS) { - const stagingX = device.createBuffer({ - size: bufBytes, - usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, - }); - const stagingY = device.createBuffer({ - size: bufBytes, - usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, - }); - const enc = device.createCommandEncoder(); - enc.copyBufferToBuffer(oxBuf, 0, stagingX, 0, bufBytes); - enc.copyBufferToBuffer(oyBuf, 0, stagingY, 0, bufBytes); - device.queue.submit([enc.finish()]); - await Promise.all([stagingX.mapAsync(GPUMapMode.READ), stagingY.mapAsync(GPUMapMode.READ)]); - const gpuX = new Uint32Array(stagingX.getMappedRange().slice(0)); - const gpuY = new Uint32Array(stagingY.getMappedRange().slice(0)); - stagingX.unmap(); - stagingY.unmap(); - stagingX.destroy(); - stagingY.destroy(); - - const rInv = (() => { - let g = R % p; - let r = p; - let x = 0n; - let y = 1n; - while (g !== 0n) { - const q = r / g; - [r, g] = [g, r - q * g]; - [x, y] = [y, x - q * y]; - } - return ((x % p) + p) % p; - })(); - - let mismatches = 0; - const MAX_REPORTED = 4; - for (let i = 0; i < TOTAL_PAIRS; i++) { - const gxM = le32u32ToBi(gpuX, i * 8); - const gyM = le32u32ToBi(gpuY, i * 8); - const gx = (gxM * rInv) % p; - const gy = (gyM * rInv) % p; - const ex = pairs[i].r.x; - const ey = pairs[i].r.y; - if (gx !== ex || gy !== ey) { - mismatches++; - if (!correctness_first_fail) { - correctness_first_fail = { - i, - expected_x: ex.toString(), - got_x: gx.toString(), - expected_y: ey.toString(), - got_y: gy.toString(), - }; - } - if (mismatches > MAX_REPORTED) break; - } - } - correctness = mismatches === 0 ? 'pass' : 'fail'; - if (mismatches === 0) { - log('ok', `correctness: pass (${TOTAL_PAIRS}/${TOTAL_PAIRS} pairs match noble reference)`); - } else { - log('err', `correctness: FAIL (${mismatches}+ mismatches; first @ i=${correctness_first_fail!.i})`); - log('err', ` expected R.x = ${correctness_first_fail!.expected_x.slice(0, 24)}...`); - log('err', ` got R.x = ${correctness_first_fail!.got_x.slice(0, 24)}...`); - } - } - - const samples: number[] = []; - for (let r = 0; r < reps; r++) { - samples.push(await dispatch()); - } + const samples = await timeDispatch(device, pipeline, bind, numWgs, reps, DISP); + const sanityOk = await readNonZero(device, outBuf, 8); const med = median(samples); - const mn = Math.min(...samples); - const mx = Math.max(...samples); - const nsPerPair = (med * 1e6) / TOTAL_PAIRS; + const totalOps = PAIRS * DISP; + const nsPerOp = (med * 1e6) / totalOps; log( - correctness === 'fail' ? 'err' : 'ok', - `BS=${bs}: median=${med.toFixed(3)}ms min=${mn.toFixed(3)}ms max=${mx.toFixed(3)}ms ns/pair=${nsPerPair.toFixed(1)} correctness=${correctness}`, + sanityOk ? 'ok' : 'err', + `S=${s}: median=${med.toFixed(3)}ms min=${Math.min(...samples).toFixed(3)}ms max=${Math.max(...samples).toFixed(3)}ms ns/op=${nsPerOp.toFixed(2)} sanity=${sanityOk ? 'OK' : 'FAIL'}`, ); - pxBuf.destroy(); - pyBuf.destroy(); - qxBuf.destroy(); - qyBuf.destroy(); - oxBuf.destroy(); - oyBuf.destroy(); + inBuf.destroy(); + dummy.destroy(); + outBuf.destroy(); + paramsBuf.destroy(); return { - bs, - tpb: TPB, + s, + wgi: WGI, + T, num_wgs: numWgs, - total_pairs: TOTAL_PAIRS, + pairs: PAIRS, + disp: DISP, + total_ops: totalOps, median_ms: med, - min_ms: mn, - max_ms: mx, - ns_per_pair: nsPerPair, + min_ms: Math.min(...samples), + max_ms: Math.max(...samples), + ns_per_op: nsPerOp, samples_ms: samples, - correctness, - correctness_first_fail, + sanity_ok: sanityOk, }; } @@ -419,41 +336,46 @@ function parseParams() { if (!Number.isFinite(reps) || reps <= 0 || reps > 50) { throw new Error(`?reps must be in (0, 50], got ${qp.get('reps')}`); } - const totalStr = qp.get('total'); - if (totalStr !== null) { - const total = parseInt(totalStr, 10); - if (!Number.isFinite(total) || total <= 0 || total > (1 << 20)) { - throw new Error(`?total must be in (0, 2^20], got ${totalStr}`); + const pairsStr = qp.get('pairs'); + if (pairsStr !== null) { + const v = parseInt(pairsStr, 10); + if (!Number.isFinite(v) || v <= 0 || v > (1 << 20)) { + throw new Error(`?pairs must be in (0, 2^20], got ${pairsStr}`); } - TOTAL_PAIRS = total; + PAIRS = v; } - const tpbStr = qp.get('tpb'); - if (tpbStr !== null) { - const tpb = parseInt(tpbStr, 10); - if (!Number.isFinite(tpb) || tpb <= 0 || tpb > 1024) { - throw new Error(`?tpb must be in (0, 1024], got ${tpbStr}`); + const wgiStr = qp.get('wgi'); + if (wgiStr !== null) { + const v = parseInt(wgiStr, 10); + if (!Number.isFinite(v) || v <= 0 || v > 1024) { + throw new Error(`?wgi must be in (0, 1024], got ${wgiStr}`); } - TPB = tpb; + WGI = v; } - const bsStr = qp.get('bs'); - if (bsStr !== null) { - const list = bsStr.split(',').map(s => parseInt(s, 10)); - for (const s of list) { - if (!Number.isFinite(s) || s <= 0 || s > 64) { - throw new Error(`?bs entries must be in (0, 64], got ${s}`); - } + const dispStr = qp.get('disp'); + if (dispStr !== null) { + const v = parseInt(dispStr, 10); + if (!Number.isFinite(v) || v <= 0 || v > 64) { + throw new Error(`?disp must be in (0, 64], got ${dispStr}`); } - BS_SWEEP = list; + DISP = v; } - for (const s of BS_SWEEP) { - if (TOTAL_PAIRS % (TPB * s) !== 0) { - throw new Error(`BS=${s} with TPB=${TPB} does not divide TOTAL_PAIRS=${TOTAL_PAIRS}`); + const sStr = qp.get('s'); + if (sStr !== null) { + const list = sStr.split(',').map(v => parseInt(v, 10)); + for (const v of list) { + if (!Number.isFinite(v) || v <= 0 || v > 256) { + throw new Error(`?s entries must be in (0, 256], got ${v}`); + } } + S_SWEEP = list; } - if (qp.get('skip_correctness') === '1') { - SKIP_CORRECTNESS = true; + for (const v of S_SWEEP) { + if (PAIRS % v !== 0) { + throw new Error(`S=${v} does not divide PAIRS=${PAIRS}`); + } } - return { reps, total: TOTAL_PAIRS, tpb: TPB, bs_sweep: BS_SWEEP, skip_correctness: SKIP_CORRECTNESS }; + return { reps, pairs: PAIRS, wgi: WGI, disp: DISP, s_sweep: S_SWEEP }; } async function main() { @@ -465,7 +387,7 @@ async function main() { benchState.params = params; log( 'info', - `params: reps=${params.reps} total=${params.total} tpb=${params.tpb} bs=[${params.bs_sweep.join(',')}] skip_correctness=${params.skip_correctness}`, + `params: reps=${params.reps} pairs=${params.pairs} wgi=${params.wgi} disp=${params.disp} s=[${params.s_sweep.join(',')}]`, ); benchState.state = 'running'; @@ -476,21 +398,18 @@ async function main() { const miscParams = compute_misc_params(p, 13); const R = miscParams.r; - log('info', `generating ${TOTAL_PAIRS} on-curve pairs via noble (this can take a few seconds)…`); - const t0 = performance.now(); - const pairs = buildPairs(TOTAL_PAIRS, 0xc0ffee); - log('info', `pair generation done in ${(performance.now() - t0).toFixed(0)} ms`); - - const sm = new ShaderManager(4, TOTAL_PAIRS, BN254_CURVE_CONFIG, false); + const sm = new ShaderManager(4, PAIRS, BN254_CURVE_CONFIG, false); - for (const bs of BS_SWEEP) { + let seed = 0x7b10; + for (const s of S_SWEEP) { try { - const r = await runOne(device, sm, bs, params.reps, R, p, pairs); + const r = await runOne(device, sm, s, params.reps, R, p, seed); benchState.results.push(r); - resultsClient.postProgress({ kind: 'batch_done', bs, median_ms: r.median_ms, ns_per_pair: r.ns_per_pair, correctness: r.correctness }); + resultsClient.postProgress({ kind: 'batch_done', s, median_ms: r.median_ms, ns_per_op: r.ns_per_op, sanity_ok: r.sanity_ok }); + seed += 0x10; } catch (e) { const msg = e instanceof Error ? e.message : String(e); - log('err', `BS=${bs} failed: ${msg} — STOPPING sweep at first failure`); + log('err', `S=${s} failed: ${msg} — STOPPING sweep at first failure`); benchState.state = 'error'; benchState.error = msg; return; diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts index aa0f18ee121f..775ad5a76ed0 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts @@ -766,23 +766,23 @@ ${packLines.join('\n')} } /** - * Standalone single-dispatch microbench for the ba_rev_packed_carry - * batch-affine scheme — packed 8x u32 storage, per-thread BS-pair - * descending suffix-product, single fr_inv_by_a per thread, ascending - * lean affine apply. No bucket indirection, no scheduler inputs. Used - * to validate that the M2 22-24 ns/pair number is achievable in the - * idealised standalone setting and to quantify the gap to integration. + * Standalone microbench for the recovered ba_rev_packed_carry kernel: + * SoA-packed 8x u32 storage across 4 input planes (A.x, A.y, P.x, P.y), + * strided per-thread access (e = t + i*T), forward prefix-product + + * single fr_inv_by_a + backward peel with resident-accumulator + * load-carry (A_{i+1} := P_i). The canonical kernel that hit ~22 + * ns/pair on M2 / Chrome 148. */ - public gen_ba_rev_packed_carry_bench_shader(tpb: number, bs: number): string { - if (tpb <= 0 || bs <= 0 || !Number.isInteger(tpb) || !Number.isInteger(bs)) { - throw new Error(`gen_ba_rev_packed_carry_bench_shader: tpb (${tpb}) and bs (${bs}) must be positive integers`); + public gen_ba_rev_packed_carry_bench_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_rev_packed_carry_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); } const dec = this.decoupledPackUnpackWgsl(); return mustache.render( ba_rev_packed_carry_bench_shader, { - tpb, - bs, + workgroup_size, + s, word_size: this.word_size, num_words: this.num_words, n0: this.n0, @@ -806,7 +806,6 @@ ${packLines.join('\n')} fr_pow_funcs, bigint_by_funcs, by_inverse_a_funcs, - packed_field_funcs, }, ); } diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index ede4abcdb80b..22e0f7916065 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -1359,114 +1359,168 @@ export const ba_rev_packed_carry_bench = `{{> structs }} {{> bigint_by_funcs }} {{> by_inverse_a_funcs }} +// MSM-integrated bucket-accumulate batch-affine kernel — packed 8x u32 +// storage + decoupled (full-ILP) pack/unpack + reversed direction + +// resident-accumulator load-carry. Drives the canonical +// ba_rev_packed_carry benchmark that reached ~22 ns/pair on M2 / Chrome +// 148 (-55% vs the production batch-affine kernel). +// +// Math is byte-identical to ba_msm_bucket_bench: forward running +// prefix-product of the S dx values in a private array, ONE +// fr_inv_by_a per chunk of S, backward peel with the lean affine +// formula (dx recomputed free in the backward pass), resident +// accumulator A.x kept in registers across the whole chunk (load-carry: +// A_{i+1} := P_i so the forward and backward passes share one global +// P_i.x load per iteration). Same Karatsuba+Yuval montmul and BY-safegcd +// fr_inv_by_a as the production stack. +// +// The single structural change from ba_msm_bucket_bench: +// global storage is the packed 254-bit value stored as 8x u32 +// (32 bytes/elem == 2x vec4), not the 20x 13-bit-limb BigInt +// (80 bytes/elem == 5x vec4). Unpack into 20x13-bit limbs only +// in-register at load and repack on store. The pack/unpack is the +// decoupled (\`{{{ dec_unpack }}}\` / \`{{{ dec_pack }}}\`) full-ILP +// straight-line form: 20 mutually-independent compile-time-constant- +// indexed limb extractions, zero loop-carried bit-cursor dependency +// chain. This cuts global traffic 2.5x (the dominant cost in the +// memory-bound batch-affine kernel) at a sub-cycle in-register cost. +// +// LAYOUT: packed elem = 2 vec4; for each of the 4 input planes +// (A.x, A.y, P.x, P.y) and 2 output planes (R.x, R.y), plane c holds +// N elements at indices c*2*N + 2*e + {0,1}. params.x = N (total +// point-adds), params.y = T (thread count = N/S). +// +// Thread t streams points e = t + i*T for i in 0..S (strided => fully +// coalesced across the apply phase). The "left" operand of add i is the +// running accumulator A_i; A_0 is the per-thread seed (plane 0/1 at +// e=t), A_{i+1} := P_i (load-carry; same global address as forward +// pass's P_i load, no extra global traffic). + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; // 8 u32 packed limbs / 4 = 2 vec4 groups + +@group(0) @binding(0) var inp: array>; +@group(0) @binding(1) var unused: array>; +@group(0) @binding(2) var outp: array>; +@group(0) @binding(3) var params: vec4; + {{{ dec_unpack }}} {{{ dec_pack }}} -{{> packed_field_funcs }} - -// Standalone single-dispatch microbench for the ba_rev_packed_carry -// batch-affine EC-add scheme: -// per-thread descending suffix-product + single fr_inv_by_a + -// ascending lean-apply, with packed 8x u32 storage at the I/O -// boundary and 13-bit BigInt limbs in every register-resident var. -// -// Each thread independently processes BS consecutive pairs from a flat -// pool (no bucket indirection, no scheduler input). Threads in a -// workgroup share no data; the only reason for TPB threads/workgroup -// is SIMD-lockstep execution of the BY safegcd inversion so its -// latency amortises across the wave. -// -// DISPATCH -// workgroups = TOTAL_PAIRS / (TPB * BS), threads/wg = TPB. -// Thread gid = wid.x * TPB + lid.x owns pairs [gid*BS, (gid+1)*BS). -// -// PHASES (entirely per-thread; no workgroup memory) -// A) Descending suffix-product over BS pairs. dx_k = Q.x_k - P.x_k. -// suf[k] = product_{j >= k} dx_j. acc threads through dx_{BS-1}, -// dx_{BS-2}, ..., dx_0. -// B) Single inv = fr_inv_by_a(suf[0]). -// C) Ascending lean apply: for k = 0..BS-1 -// inv_dx_k = (k+1 < BS) ? inv * suf[k+1] : inv -// lambda = (Q.y - P.y) * inv_dx_k -// R.x = lambda^2 - P.x - Q.x -// R.y = lambda * (P.x - R.x) - P.y -// if k+1 < BS: inv = inv * dx_k (forward-propagate) -// -// LOOP BOUNDS — every loop bound is a compile-time Mustache const -// (BS, NUM_WORDS). No data-dependent unbounded loops. +fn load_be_packed(plane_base: u32, e: u32, N: u32) -> BigInt { + // plane_base is in vec4 units; per plane: 2*N vec4 (PG=2). + let base = plane_base + PG * e; + let q0 = inp[base + 0u]; + let q1 = inp[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} -const TPB: u32 = {{ tpb }}u; -const BS: u32 = {{ bs }}u; +fn store_be_packed(plane_base: u32, e: u32, N: u32, val: ptr) { + let w = pack_limbs_to_256(val); + let base = plane_base + PG * e; + outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} -@group(0) @binding(0) var inputs_p_x: array>; -@group(0) @binding(1) var inputs_p_y: array>; -@group(0) @binding(2) var inputs_q_x: array>; -@group(0) @binding(3) var inputs_q_y: array>; -@group(0) @binding(4) var outputs_x: array>; -@group(0) @binding(5) var outputs_y: array>; +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} -@compute -@workgroup_size({{ tpb }}) +@compute @workgroup_size({{ workgroup_size }}) fn main(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - let base = tid * BS; - - // Phase A — descending suffix-product over BS pairs. - var suf: array; - var acc: BigInt; - for (var jj: u32 = 0u; jj < BS; jj = jj + 1u) { - let k = BS - 1u - jj; - let idx = base + k; - var p_x: BigInt = field_load_ro(idx, &inputs_p_x); - var q_x: BigInt = field_load_ro(idx, &inputs_q_x); - var dx: BigInt = fr_sub(&q_x, &p_x); - if (jj == 0u) { + let N = params.x; + let T = params.y; + let t = gid.x; + if (t >= T) { return; } + + // Plane bases in vec4 units. Each plane spans PG*N vec4. + let plane = PG * N; + let ax_base = 0u * plane; + let ay_base = 1u * plane; + let px_base = 2u * plane; + let py_base = 3u * plane; + + // Resident accumulator A.x stays in registers across the whole + // chunk (drives the forward dx prefix chain). A.y is only needed in + // the backward peel and is re-loaded there from the same SoA plane. + var acc_x = load_be_packed(ax_base, t, N); + + // Forward pass: running prefix-product of the S dx values + // dx_i = P_i.x - A_i.x. A_i is the prefix accumulator (resident). + var pref: array; + var acc: BigInt = get_r(); + for (var i = 0u; i < S; i = i + 1u) { + let e = t + i * T; + var p_x = load_be_packed(px_base, e, N); + var dx = fr_sub(&p_x, &acc_x); + if (i == 0u) { acc = dx; } else { acc = montgomery_product(&acc, &dx); } - suf[k] = acc; - } - - // Phase B — single inversion per thread. - var inv: BigInt = fr_inv_by_a(suf[0]); - - // Phase C — ascending lean apply. - for (var k: u32 = 0u; k < BS; k = k + 1u) { - let idx = base + k; - var p_x: BigInt = field_load_ro(idx, &inputs_p_x); - var p_y: BigInt = field_load_ro(idx, &inputs_p_y); - var q_x: BigInt = field_load_ro(idx, &inputs_q_x); - var q_y: BigInt = field_load_ro(idx, &inputs_q_y); + pref[i] = acc; + // Resident accumulator advances along the streamed chain: + // A_0 is the seed, A_{i+1} := P_i. Points are independent + // (P_i.x != A_i.x) so every dx is a well-defined nonzero + // difference. inv_dx is deferred to the backward pass (ONE + // fr_inv_by_a per chunk of S); A stays in registers throughout. + acc_x = p_x; + } + + var inv: BigInt = fr_inv_by_a(acc); + + // Backward peel + lean affine formula (dx recomputed free). + for (var jj = 0u; jj < S; jj = jj + 1u) { + let i = S - 1u - jj; + let e = t + i * T; + var p_x = load_be_packed(px_base, e, N); + var p_y = load_be_packed(py_base, e, N); + + // A_i (left operand): A_0 is the seed, A_i = P_{i-1} for i>0 + // (matches the forward acc_x recurrence; points independent so + // dx = P_i.x - A_i.x is always well-defined and nonzero). + var a_x: BigInt; + var a_y: BigInt; + if (i == 0u) { + a_x = load_be_packed(ax_base, t, N); + a_y = load_be_packed(ay_base, t, N); + } else { + let ep = t + (i - 1u) * T; + a_x = load_be_packed(px_base, ep, N); + a_y = load_be_packed(py_base, ep, N); + } var inv_dx: BigInt; - if (k + 1u < BS) { - var sp = suf[k + 1u]; - inv_dx = montgomery_product(&inv, &sp); - } else { + if (i == 0u) { inv_dx = inv; + } else { + var pp = pref[i - 1u]; + inv_dx = montgomery_product(&inv, &pp); } - var dy: BigInt = fr_sub(&q_y, &p_y); - var lambda: BigInt = montgomery_product(&dy, &inv_dx); - var r_x: BigInt = montgomery_product(&lambda, &lambda); + var lambda = fr_sub(&p_y, &a_y); + lambda = montgomery_product(&lambda, &inv_dx); + var r_x = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &a_x); r_x = fr_sub(&r_x, &p_x); - r_x = fr_sub(&r_x, &q_x); - var dxb: BigInt = fr_sub(&p_x, &r_x); - var r_y: BigInt = montgomery_product(&lambda, &dxb); - r_y = fr_sub(&r_y, &p_y); + var r_y = fr_sub(&a_x, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &a_y); - field_store(idx, &outputs_x, &r_x); - field_store(idx, &outputs_y, &r_y); + store_be_packed(0u * plane, e, N, &r_x); + store_be_packed(1u * plane, e, N, &r_y); - if (k + 1u < BS) { - var dxf: BigInt = fr_sub(&q_x, &p_x); - inv = montgomery_product(&inv, &dxf); + if (i != 0u) { + var dx_back = fr_sub(&p_x, &a_x); + inv = montgomery_product(&inv, &dx_back); } } - - {{{ recompile }}} } `; diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_rev_packed_carry_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_rev_packed_carry_bench.template.wgsl index 51b795840159..9ea7af05849e 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_rev_packed_carry_bench.template.wgsl +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_rev_packed_carry_bench.template.wgsl @@ -6,112 +6,166 @@ {{> bigint_by_funcs }} {{> by_inverse_a_funcs }} -{{{ dec_unpack }}} - -{{{ dec_pack }}} - -{{> packed_field_funcs }} - -// Standalone single-dispatch microbench for the ba_rev_packed_carry -// batch-affine EC-add scheme: -// per-thread descending suffix-product + single fr_inv_by_a + -// ascending lean-apply, with packed 8x u32 storage at the I/O -// boundary and 13-bit BigInt limbs in every register-resident var. +// MSM-integrated bucket-accumulate batch-affine kernel — packed 8x u32 +// storage + decoupled (full-ILP) pack/unpack + reversed direction + +// resident-accumulator load-carry. Drives the canonical +// ba_rev_packed_carry benchmark that reached ~22 ns/pair on M2 / Chrome +// 148 (-55% vs the production batch-affine kernel). // -// Each thread independently processes BS consecutive pairs from a flat -// pool (no bucket indirection, no scheduler input). Threads in a -// workgroup share no data; the only reason for TPB threads/workgroup -// is SIMD-lockstep execution of the BY safegcd inversion so its -// latency amortises across the wave. +// Math is byte-identical to ba_msm_bucket_bench: forward running +// prefix-product of the S dx values in a private array, ONE +// fr_inv_by_a per chunk of S, backward peel with the lean affine +// formula (dx recomputed free in the backward pass), resident +// accumulator A.x kept in registers across the whole chunk (load-carry: +// A_{i+1} := P_i so the forward and backward passes share one global +// P_i.x load per iteration). Same Karatsuba+Yuval montmul and BY-safegcd +// fr_inv_by_a as the production stack. // -// DISPATCH -// workgroups = TOTAL_PAIRS / (TPB * BS), threads/wg = TPB. -// Thread gid = wid.x * TPB + lid.x owns pairs [gid*BS, (gid+1)*BS). +// The single structural change from ba_msm_bucket_bench: +// global storage is the packed 254-bit value stored as 8x u32 +// (32 bytes/elem == 2x vec4), not the 20x 13-bit-limb BigInt +// (80 bytes/elem == 5x vec4). Unpack into 20x13-bit limbs only +// in-register at load and repack on store. The pack/unpack is the +// decoupled (`{{{ dec_unpack }}}` / `{{{ dec_pack }}}`) full-ILP +// straight-line form: 20 mutually-independent compile-time-constant- +// indexed limb extractions, zero loop-carried bit-cursor dependency +// chain. This cuts global traffic 2.5x (the dominant cost in the +// memory-bound batch-affine kernel) at a sub-cycle in-register cost. // -// PHASES (entirely per-thread; no workgroup memory) -// A) Descending suffix-product over BS pairs. dx_k = Q.x_k - P.x_k. -// suf[k] = product_{j >= k} dx_j. acc threads through dx_{BS-1}, -// dx_{BS-2}, ..., dx_0. -// B) Single inv = fr_inv_by_a(suf[0]). -// C) Ascending lean apply: for k = 0..BS-1 -// inv_dx_k = (k+1 < BS) ? inv * suf[k+1] : inv -// lambda = (Q.y - P.y) * inv_dx_k -// R.x = lambda^2 - P.x - Q.x -// R.y = lambda * (P.x - R.x) - P.y -// if k+1 < BS: inv = inv * dx_k (forward-propagate) +// LAYOUT: packed elem = 2 vec4; for each of the 4 input planes +// (A.x, A.y, P.x, P.y) and 2 output planes (R.x, R.y), plane c holds +// N elements at indices c*2*N + 2*e + {0,1}. params.x = N (total +// point-adds), params.y = T (thread count = N/S). // -// LOOP BOUNDS — every loop bound is a compile-time Mustache const -// (BS, NUM_WORDS). No data-dependent unbounded loops. +// Thread t streams points e = t + i*T for i in 0..S (strided => fully +// coalesced across the apply phase). The "left" operand of add i is the +// running accumulator A_i; A_0 is the per-thread seed (plane 0/1 at +// e=t), A_{i+1} := P_i (load-carry; same global address as forward +// pass's P_i load, no extra global traffic). + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; // 8 u32 packed limbs / 4 = 2 vec4 groups + +@group(0) @binding(0) var inp: array>; +@group(0) @binding(1) var unused: array>; +@group(0) @binding(2) var outp: array>; +@group(0) @binding(3) var params: vec4; + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +fn load_be_packed(plane_base: u32, e: u32, N: u32) -> BigInt { + // plane_base is in vec4 units; per plane: 2*N vec4 (PG=2). + let base = plane_base + PG * e; + let q0 = inp[base + 0u]; + let q1 = inp[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} -const TPB: u32 = {{ tpb }}u; -const BS: u32 = {{ bs }}u; +fn store_be_packed(plane_base: u32, e: u32, N: u32, val: ptr) { + let w = pack_limbs_to_256(val); + let base = plane_base + PG * e; + outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} -@group(0) @binding(0) var inputs_p_x: array>; -@group(0) @binding(1) var inputs_p_y: array>; -@group(0) @binding(2) var inputs_q_x: array>; -@group(0) @binding(3) var inputs_q_y: array>; -@group(0) @binding(4) var outputs_x: array>; -@group(0) @binding(5) var outputs_y: array>; +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} -@compute -@workgroup_size({{ tpb }}) +@compute @workgroup_size({{ workgroup_size }}) fn main(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - let base = tid * BS; - - // Phase A — descending suffix-product over BS pairs. - var suf: array; - var acc: BigInt; - for (var jj: u32 = 0u; jj < BS; jj = jj + 1u) { - let k = BS - 1u - jj; - let idx = base + k; - var p_x: BigInt = field_load_ro(idx, &inputs_p_x); - var q_x: BigInt = field_load_ro(idx, &inputs_q_x); - var dx: BigInt = fr_sub(&q_x, &p_x); - if (jj == 0u) { + let N = params.x; + let T = params.y; + let t = gid.x; + if (t >= T) { return; } + + // Plane bases in vec4 units. Each plane spans PG*N vec4. + let plane = PG * N; + let ax_base = 0u * plane; + let ay_base = 1u * plane; + let px_base = 2u * plane; + let py_base = 3u * plane; + + // Resident accumulator A.x stays in registers across the whole + // chunk (drives the forward dx prefix chain). A.y is only needed in + // the backward peel and is re-loaded there from the same SoA plane. + var acc_x = load_be_packed(ax_base, t, N); + + // Forward pass: running prefix-product of the S dx values + // dx_i = P_i.x - A_i.x. A_i is the prefix accumulator (resident). + var pref: array; + var acc: BigInt = get_r(); + for (var i = 0u; i < S; i = i + 1u) { + let e = t + i * T; + var p_x = load_be_packed(px_base, e, N); + var dx = fr_sub(&p_x, &acc_x); + if (i == 0u) { acc = dx; } else { acc = montgomery_product(&acc, &dx); } - suf[k] = acc; + pref[i] = acc; + // Resident accumulator advances along the streamed chain: + // A_0 is the seed, A_{i+1} := P_i. Points are independent + // (P_i.x != A_i.x) so every dx is a well-defined nonzero + // difference. inv_dx is deferred to the backward pass (ONE + // fr_inv_by_a per chunk of S); A stays in registers throughout. + acc_x = p_x; } - // Phase B — single inversion per thread. - var inv: BigInt = fr_inv_by_a(suf[0]); + var inv: BigInt = fr_inv_by_a(acc); - // Phase C — ascending lean apply. - for (var k: u32 = 0u; k < BS; k = k + 1u) { - let idx = base + k; - var p_x: BigInt = field_load_ro(idx, &inputs_p_x); - var p_y: BigInt = field_load_ro(idx, &inputs_p_y); - var q_x: BigInt = field_load_ro(idx, &inputs_q_x); - var q_y: BigInt = field_load_ro(idx, &inputs_q_y); + // Backward peel + lean affine formula (dx recomputed free). + for (var jj = 0u; jj < S; jj = jj + 1u) { + let i = S - 1u - jj; + let e = t + i * T; + var p_x = load_be_packed(px_base, e, N); + var p_y = load_be_packed(py_base, e, N); - var inv_dx: BigInt; - if (k + 1u < BS) { - var sp = suf[k + 1u]; - inv_dx = montgomery_product(&inv, &sp); + // A_i (left operand): A_0 is the seed, A_i = P_{i-1} for i>0 + // (matches the forward acc_x recurrence; points independent so + // dx = P_i.x - A_i.x is always well-defined and nonzero). + var a_x: BigInt; + var a_y: BigInt; + if (i == 0u) { + a_x = load_be_packed(ax_base, t, N); + a_y = load_be_packed(ay_base, t, N); } else { + let ep = t + (i - 1u) * T; + a_x = load_be_packed(px_base, ep, N); + a_y = load_be_packed(py_base, ep, N); + } + + var inv_dx: BigInt; + if (i == 0u) { inv_dx = inv; + } else { + var pp = pref[i - 1u]; + inv_dx = montgomery_product(&inv, &pp); } - var dy: BigInt = fr_sub(&q_y, &p_y); - var lambda: BigInt = montgomery_product(&dy, &inv_dx); - var r_x: BigInt = montgomery_product(&lambda, &lambda); + var lambda = fr_sub(&p_y, &a_y); + lambda = montgomery_product(&lambda, &inv_dx); + var r_x = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &a_x); r_x = fr_sub(&r_x, &p_x); - r_x = fr_sub(&r_x, &q_x); - var dxb: BigInt = fr_sub(&p_x, &r_x); - var r_y: BigInt = montgomery_product(&lambda, &dxb); - r_y = fr_sub(&r_y, &p_y); + var r_y = fr_sub(&a_x, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &a_y); - field_store(idx, &outputs_x, &r_x); - field_store(idx, &outputs_y, &r_y); + store_be_packed(0u * plane, e, N, &r_x); + store_be_packed(1u * plane, e, N, &r_y); - if (k + 1u < BS) { - var dxf: BigInt = fr_sub(&q_x, &p_x); - inv = montgomery_product(&inv, &dxf); + if (i != 0u) { + var dx_back = fr_sub(&p_x, &a_x); + inv = montgomery_product(&inv, &dx_back); } } - - {{{ recompile }}} } From 8bebbf777914c58ce91f5ee02fb336e3a1c92aa1 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Tue, 19 May 2026 17:20:29 +0000 Subject: [PATCH 08/33] fix(bb/msm): reword ba_rev_packed_carry header so mustache tags aren't expanded in-comment The recovered template's header comment literally contained the {{{ dec_unpack }}} / {{{ dec_pack }}} tag text; mustache expanded them inside the // comment block, spilling the unpack256_to_limbs function body out of the comment (WGSL: statement found outside of function body). Reworded to name the injected helpers instead. --- .../ts/src/msm_webgpu/wgsl/_generated/shaders.ts | 11 ++++++----- .../wgsl/cuzk/ba_rev_packed_carry_bench.template.wgsl | 11 ++++++----- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index 22e0f7916065..50818338cd15 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -1379,11 +1379,12 @@ export const ba_rev_packed_carry_bench = `{{> structs }} // (32 bytes/elem == 2x vec4), not the 20x 13-bit-limb BigInt // (80 bytes/elem == 5x vec4). Unpack into 20x13-bit limbs only // in-register at load and repack on store. The pack/unpack is the -// decoupled (\`{{{ dec_unpack }}}\` / \`{{{ dec_pack }}}\`) full-ILP -// straight-line form: 20 mutually-independent compile-time-constant- -// indexed limb extractions, zero loop-carried bit-cursor dependency -// chain. This cuts global traffic 2.5x (the dominant cost in the -// memory-bound batch-affine kernel) at a sub-cycle in-register cost. +// decoupled full-ILP straight-line form (injected below as +// unpack256_to_limbs / pack_limbs_to_256): 20 mutually-independent +// compile-time-constant-indexed limb extractions, zero loop-carried +// bit-cursor dependency chain. This cuts global traffic 2.5x (the +// dominant cost in the memory-bound batch-affine kernel) at a +// sub-cycle in-register cost. // // LAYOUT: packed elem = 2 vec4; for each of the 4 input planes // (A.x, A.y, P.x, P.y) and 2 output planes (R.x, R.y), plane c holds diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_rev_packed_carry_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_rev_packed_carry_bench.template.wgsl index 9ea7af05849e..f0f6a7206649 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_rev_packed_carry_bench.template.wgsl +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_rev_packed_carry_bench.template.wgsl @@ -26,11 +26,12 @@ // (32 bytes/elem == 2x vec4), not the 20x 13-bit-limb BigInt // (80 bytes/elem == 5x vec4). Unpack into 20x13-bit limbs only // in-register at load and repack on store. The pack/unpack is the -// decoupled (`{{{ dec_unpack }}}` / `{{{ dec_pack }}}`) full-ILP -// straight-line form: 20 mutually-independent compile-time-constant- -// indexed limb extractions, zero loop-carried bit-cursor dependency -// chain. This cuts global traffic 2.5x (the dominant cost in the -// memory-bound batch-affine kernel) at a sub-cycle in-register cost. +// decoupled full-ILP straight-line form (injected below as +// unpack256_to_limbs / pack_limbs_to_256): 20 mutually-independent +// compile-time-constant-indexed limb extractions, zero loop-carried +// bit-cursor dependency chain. This cuts global traffic 2.5x (the +// dominant cost in the memory-bound batch-affine kernel) at a +// sub-cycle in-register cost. // // LAYOUT: packed elem = 2 vec4; for each of the 4 input planes // (A.x, A.y, P.x, P.y) and 2 output planes (R.x, R.y), plane c holds From 7de03b1d54a67f55f21c7ef3864d548990a4bec0 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Tue, 19 May 2026 18:23:52 +0000 Subject: [PATCH 09/33] =?UTF-8?q?feat(bb/msm):=20bench-msm-chain=20?= =?UTF-8?q?=E2=80=94=20marshal+chain=20pair-tree=20level-0=20pipeline?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Standalone WebGPU bench wiring the recovered ba_rev_packed_carry chain kernel into a Pippenger MSM bucket-accumulate pair-tree, measuring the first reduction level (N input points -> N/2 useful pair sums). Three pieces: - ba_marshal_chain_bench.template.wgsl: pure memory-shuffle kernel. Reads a CSR-sorted point index list + chunk plan, gathers packed point coords from an SoA-packed pool, writes them into the strided SoA layout (4 planes, PG=2 vec4/elem, T*S elements per plane, element index e = t + i*T) the chain kernel consumes. point_pool[0] is reserved as a universal decoy seed so the chain kernel's first dx per chunk is well-defined; csr_indices values are 1-based. - shader_manager.gen_ba_marshal_chain_shader(workgroup_size, s). - dev/msm-webgpu/bench-msm-chain.{ts,html}: host harness. - Synthetic CSR generator: N points uniformly assigned to B buckets, sorted; row-pointer offsets + per-bucket counts. With N=131072, B=8192 the average bucket size is 16, matching S=16 sweet spot. - Chunk plan: for each bucket with count >= S, emit floor(count/S) chunks of S consecutive points (tail points with count mod S are skipped in v1 and reported as `tail_points`). - Two-stage timed bench: marshal_ms via DISP=8 back-to-back dispatch loop, then chain_ms separately. Reports marshal_ns_per_pt, chain_ns_per_pt, combined_ns_per_pt, density = T*S/PAIRS. - Sweep S in {16, 32, 64} at fixed TPB=64. - Sanity: readNonZero on chain output (first packed elem of R.x). Scope of v1 measures pair-tree level 0 only. Follow-ons not in this PR: recursive level-1..log2(S) reduce passes (same chain kernel applied to shrinking workload, ~25 ns/pt per level), tail-bucket handling for count + + + + MSM bucket-accumulate pair-tree level-0 bench (marshal+chain) + + + +

MSM bucket-accumulate pair-tree level-0 bench (marshal + chain)

+

Query params: ?reps=R&pairs=N&buckets=B&s=S&wgi=W&disp=D

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-chain.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-chain.ts new file mode 100644 index 000000000000..b644b8ab1598 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-chain.ts @@ -0,0 +1,625 @@ +/// +// bench-msm-chain — Standalone WebGPU bench for the marshal + chain +// pair-tree level-0 pipeline that integrates the ba_rev_packed_carry +// chain kernel into MSM bucket accumulate. +// +// The harness: +// 1. Generates a point pool of N+1 random Montgomery points in the +// SoA-packed layout (2 planes, PG=2 vec4/elem). Index 0 is a +// universal decoy seed; indices 1..N are real points. +// 2. Generates a synthetic CSR: B buckets, each point assigned to a +// uniformly random bucket. csr_indices = points sorted by bucket; +// offset[] = CSR row pointers; count[b] = bucket b's point count. +// 3. Builds a chunk plan from the dense slices: for each bucket b +// with count[b] >= S, splits it into floor(count[b]/S) chunks of +// exactly S points. Total dense chunks T_dense. +// 4. Runs the marshal kernel: T_dense threads, each gathers S point +// coords from the pool into the strided chain layout. +// 5. Runs the recovered ba_rev_packed_carry chain kernel on the +// marshaled layout. +// 6. Times marshal and chain separately (DISP back-to-back dispatches +// per timed sample, amortising submit + drain). +// +// Reported metrics, sweeping S in {16, 32, 64}: +// marshal_ns_per_pt — ns per input point processed by marshal +// chain_ns_per_pt — ns per input point processed by chain +// combined_ns_per_pt — marshal + chain +// density — fraction of N points covered by dense chunks +// +// Out of scope here (covered by follow-on passes): +// - Reduction passes that fold pair-tree levels 1..log2(S)/2 into +// per-bucket totals. These reuse the same chain kernel with +// decreasing T; cost is ~25 ns/pt per level, ~3 levels for S=16. +// - Tail handling for buckets with count < S. Will reuse the +// workgroup-scan batch-affine path or a variable-length variant. + +import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js'; +import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js'; +import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js'; +import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js'; +import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js'; +import { makeResultsClient } from './results_post.js'; + +const PG = 2; +const DEFAULT_PAIRS = 1 << 17; // 131072 +const DEFAULT_BUCKETS = 1 << 13; // 8192 -> avg bucket size = 16 +const DEFAULT_WGI = 64; +const DEFAULT_DISP = 8; +const DEFAULT_S_SWEEP: readonly number[] = [16, 32, 64]; + +let PAIRS = DEFAULT_PAIRS; +let BUCKETS = DEFAULT_BUCKETS; +let WGI = DEFAULT_WGI; +let DISP = DEFAULT_DISP; +let S_SWEEP: readonly number[] = DEFAULT_S_SWEEP; + +function makeRng(seed: number): () => number { + let state = (seed >>> 0) || 1; + return () => { + state = (Math.imul(state, 1664525) + 1013904223) >>> 0; + return state; + }; +} + +function randomBelow(p: bigint, rng: () => number): bigint { + const bitlen = p.toString(2).length; + const byteLen = Math.ceil(bitlen / 8); + for (;;) { + let v = 0n; + for (let i = 0; i < byteLen; i++) v = (v << 8n) | BigInt(rng() & 0xff); + v &= (1n << BigInt(bitlen)) - 1n; + if (v > 0n && v < p) return v; + } +} + +function bigintToPackedU32x8(v: bigint): Uint32Array { + const w = new Uint32Array(8); + let x = v; + for (let i = 0; i < 8; i++) { + w[i] = Number(x & 0xffffffffn); + x >>= 32n; + } + return w; +} + +function median(xs: number[]): number { + if (xs.length === 0) return NaN; + const s = xs.slice().sort((a, b) => a - b); + return s[Math.floor(s.length / 2)]; +} + +// Pack a pool of M (= N+1) random Montgomery points into the SoA layout +// the marshal kernel expects: 2 planes (P.x, P.y), PG vec4 per element. +// Pool index 0 is the decoy seed. +function buildPointPool(poolSize: number, R: bigint, p: bigint, rng: () => number): Uint32Array { + const M = poolSize; + const buf = new Uint32Array(2 * PG * M * 4); + for (let e = 0; e < M; e++) { + const x = (randomBelow(p, rng) * R) % p; + const y = (randomBelow(p, rng) * R) % p; + const wx = bigintToPackedU32x8(x); + const wy = bigintToPackedU32x8(y); + for (let v = 0; v < PG; v++) { + const baseX = ((0 * PG + v) * M + e) * 4; + const baseY = ((1 * PG + v) * M + e) * 4; + buf[baseX + 0] = wx[4 * v + 0]; + buf[baseX + 1] = wx[4 * v + 1]; + buf[baseX + 2] = wx[4 * v + 2]; + buf[baseX + 3] = wx[4 * v + 3]; + buf[baseY + 0] = wy[4 * v + 0]; + buf[baseY + 1] = wy[4 * v + 1]; + buf[baseY + 2] = wy[4 * v + 2]; + buf[baseY + 3] = wy[4 * v + 3]; + } + } + return buf; +} + +// Synthetic CSR: assigns each of the N points to a random bucket in +// [0, B). Returns csr_indices (point indices sorted by bucket, values in +// [1, N+1)) and offsets[b] giving the first csr_indices position for +// bucket b. Point indices are 1-based so index 0 in the pool is reserved +// for the decoy. +function buildSyntheticCSR( + N: number, + B: number, + rng: () => number, +): { csrIndices: Uint32Array; offsets: Uint32Array; counts: Uint32Array } { + const bucket = new Uint32Array(N); + const counts = new Uint32Array(B); + for (let i = 0; i < N; i++) { + const b = rng() % B; + bucket[i] = b; + counts[b]++; + } + const offsets = new Uint32Array(B + 1); + for (let b = 0; b < B; b++) offsets[b + 1] = offsets[b] + counts[b]; + const cursor = new Uint32Array(B); + const csrIndices = new Uint32Array(N); + for (let i = 0; i < N; i++) { + const b = bucket[i]; + csrIndices[offsets[b] + cursor[b]++] = i + 1; // 1-based: 0 is decoy + } + return { csrIndices, offsets, counts }; +} + +// Walk the CSR row pointers and produce the dense chunk plan. For each +// bucket b with count[b] >= S, emit floor(count[b]/S) chunks each +// pointing at S consecutive csr_indices entries. tail = sum of leftover +// points (count[b] mod S) that this v1 bench skips. +function buildChunkPlan( + offsets: Uint32Array, + counts: Uint32Array, + S: number, +): { chunkPlan: Uint32Array; T: number; tailPoints: number } { + const B = counts.length; + let T = 0; + let tail = 0; + for (let b = 0; b < B; b++) { + const c = counts[b]; + T += Math.floor(c / S); + tail += c % S; + } + const chunkPlan = new Uint32Array(2 * T); + let t = 0; + for (let b = 0; b < B; b++) { + const c = counts[b]; + const nChunks = Math.floor(c / S); + for (let k = 0; k < nChunks; k++) { + chunkPlan[2 * t + 0] = b; + chunkPlan[2 * t + 1] = offsets[b] + k * S; + t++; + } + } + return { chunkPlan, T, tailPoints: tail }; +} + +interface PerSizeResult { + s: number; + wgi: number; + pairs: number; + buckets: number; + T: number; + tail_points: number; + density: number; + disp: number; + marshal_median_ms: number; + marshal_min_ms: number; + marshal_max_ms: number; + marshal_ns_per_pt: number; + chain_median_ms: number; + chain_min_ms: number; + chain_max_ms: number; + chain_ns_per_pt: number; + combined_ns_per_pt: number; + marshal_samples_ms: number[]; + chain_samples_ms: number[]; + sanity_ok: boolean; +} + +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: { reps: number; pairs: number; buckets: number; wgi: number; disp: number; s_sweep: readonly number[] } | null; + results: PerSizeResult[]; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { + state: 'boot', + params: null, + results: [], + error: null, + log: [], +}; +(window as unknown as { __bench: BenchState }).__bench = benchState; + +const resultsClient = makeResultsClient({ page: 'bench-msm-chain' }); +(window as unknown as { __runId: string }).__runId = resultsClient.runId; + +async function postFinal(): Promise { + await resultsClient.postResults({ + state: benchState.state, + params: benchState.params, + results: benchState.results, + error: benchState.error, + log: benchState.log, + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); +} + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-msm-chain] ${msg}`); +} + +interface PipelineInfo { + pipeline: GPUComputePipeline; + layout: GPUBindGroupLayout; +} + +async function compile( + device: GPUDevice, + code: string, + cacheKey: string, + bindLayout: GPUBindGroupLayout, +): Promise { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + const errLines: string[] = []; + for (const msg of info.messages) { + const line = `[shader ${cacheKey}] ${msg.type}: ${msg.message} (line ${msg.lineNum}, col ${msg.linePos})`; + if (msg.type === 'error') { + console.error(line); + log('err', line); + errLines.push(line); + hasError = true; + } else { + console.warn(line); + log('warn', line); + } + } + if (hasError) { + throw new Error(`WGSL compile failed for ${cacheKey}: ${errLines.join(' | ')}`); + } + return device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [bindLayout] }), + compute: { module, entryPoint: 'main' }, + }); +} + +function chainLayout(device: GPUDevice): GPUBindGroupLayout { + return device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); +} + +function marshalLayout(device: GPUDevice): GPUBindGroupLayout { + return device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); +} + +async function timeDispatchPasses( + device: GPUDevice, + pipeline: GPUComputePipeline, + bind: GPUBindGroup, + numWgs: number, + reps: number, + passes: number, +): Promise { + // warmup + { + const enc = device.createCommandEncoder(); + for (let pIdx = 0; pIdx < passes; pIdx++) { + const pass = enc.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroups(numWgs, 1, 1); + pass.end(); + } + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + } + const samples: number[] = []; + for (let r = 0; r < reps; r++) { + const enc = device.createCommandEncoder(); + for (let pIdx = 0; pIdx < passes; pIdx++) { + const pass = enc.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroups(numWgs, 1, 1); + pass.end(); + } + const t0 = performance.now(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + samples.push(performance.now() - t0); + } + return samples; +} + +async function readNonZero(device: GPUDevice, buf: GPUBuffer, u32Count: number): Promise { + const bytes = u32Count * 4; + const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }); + const enc = device.createCommandEncoder(); + enc.copyBufferToBuffer(buf, 0, staging, 0, bytes); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + await staging.mapAsync(GPUMapMode.READ); + const u32 = new Uint32Array(staging.getMappedRange().slice(0)); + staging.unmap(); + staging.destroy(); + for (let i = 0; i < u32.length; i++) if (u32[i] !== 0) return true; + return false; +} + +async function runOne( + device: GPUDevice, + sm: ShaderManager, + s: number, + reps: number, + R: bigint, + p: bigint, + seed: number, +): Promise { + log('info', `=== S=${s}: PAIRS=${PAIRS} BUCKETS=${BUCKETS} WGI=${WGI} DISP=${DISP}`); + + const rng = makeRng(seed); + const poolSize = PAIRS + 1; // index 0 reserved as decoy + const poolU32 = buildPointPool(poolSize, R, p, rng); + const { csrIndices, counts } = buildSyntheticCSR(PAIRS, BUCKETS, rng); + // Recompute offsets locally (buildSyntheticCSR returns them too, but + // we only need it inside buildChunkPlan). + const offsets = new Uint32Array(BUCKETS + 1); + for (let b = 0; b < BUCKETS; b++) offsets[b + 1] = offsets[b] + counts[b]; + + const { chunkPlan, T, tailPoints } = buildChunkPlan(offsets, counts, s); + const density = (T * s) / PAIRS; + const numWgs = Math.ceil(T / WGI); + log( + 'info', + `chunk plan: T=${T} chunks (S=${s} each) -> ${T * s} pts (${(density * 100).toFixed(1)}% of ${PAIRS}); tail=${tailPoints} pts skipped`, + ); + + if (T === 0) { + throw new Error(`S=${s}: no dense chunks (every bucket has count<${s}). Try smaller S or fewer buckets.`); + } + + // GPU buffers. + const mkSb = (size: number, copyDst: boolean, copySrc: boolean): GPUBuffer => { + let usage = GPUBufferUsage.STORAGE; + if (copyDst) usage |= GPUBufferUsage.COPY_DST; + if (copySrc) usage |= GPUBufferUsage.COPY_SRC; + return device.createBuffer({ size, usage }); + }; + + const poolBuf = mkSb(poolU32.byteLength, true, false); + const csrBuf = mkSb(csrIndices.byteLength, true, false); + const chunkBuf = mkSb(chunkPlan.byteLength, true, false); + const chainBytes = 4 * PG * (T * s) * 4 * 4; // 4 planes * PG vec4/elem * (T*S) elems * 4 u32/vec4 * 4 B/u32 + const chainBuf = mkSb(chainBytes, false, true); + const marshalParams = device.createBuffer({ + size: 16, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + const chainParams = device.createBuffer({ + size: 16, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + + device.queue.writeBuffer(poolBuf, 0, poolU32); + device.queue.writeBuffer(csrBuf, 0, csrIndices); + device.queue.writeBuffer(chunkBuf, 0, chunkPlan); + // marshal: params = [T, poolSize, _, _] + device.queue.writeBuffer(marshalParams, 0, new Uint32Array([T, poolSize, 0, 0])); + // chain: params = [N_chain, T, _, _] with N_chain = T*S + device.queue.writeBuffer(chainParams, 0, new Uint32Array([T * s, T, 0, 0])); + + // Compile marshal pipeline. + const marshalCode = sm.gen_ba_marshal_chain_shader(WGI, s); + log('info', `marshal shader ${marshalCode.length} chars`); + const mLayout = marshalLayout(device); + const marshalPipeline = await compile(device, marshalCode, `marshal-W${WGI}-S${s}`, mLayout); + const marshalBind = device.createBindGroup({ + layout: mLayout, + entries: [ + { binding: 0, resource: { buffer: csrBuf } }, + { binding: 1, resource: { buffer: chunkBuf } }, + { binding: 2, resource: { buffer: poolBuf } }, + { binding: 3, resource: { buffer: chainBuf } }, + { binding: 4, resource: { buffer: marshalParams } }, + ], + }); + + // Compile chain pipeline. + const chainCode = sm.gen_ba_rev_packed_carry_bench_shader(WGI, s); + log('info', `chain shader ${chainCode.length} chars`); + const cLayout = chainLayout(device); + const chainPipeline = await compile(device, chainCode, `chain-W${WGI}-S${s}`, cLayout); + const dummy = device.createBuffer({ size: 16, usage: GPUBufferUsage.STORAGE }); + // Chain output buffer (separate from input chain_buf since the kernel + // writes its R outputs to a 2-plane output buffer). + const chainOutBytes = 2 * PG * (T * s) * 4 * 4; + const chainOutBuf = mkSb(chainOutBytes, false, true); + const chainBind = device.createBindGroup({ + layout: cLayout, + entries: [ + { binding: 0, resource: { buffer: chainBuf } }, + { binding: 1, resource: { buffer: dummy } }, + { binding: 2, resource: { buffer: chainOutBuf } }, + { binding: 3, resource: { buffer: chainParams } }, + ], + }); + + log('info', `marshal: numWgs=${numWgs}, ${T} threads, ${T * s} pts gathered/dispatch`); + log('info', `chain : numWgs=${numWgs}, ${T} threads, S=${s} adds/thread = ${T * s} pts processed/dispatch`); + + // Marshal must run at least once before chain (chain reads chain_buf). + // Warmup is built into timeDispatchPasses. + const marshalSamples = await timeDispatchPasses(device, marshalPipeline, marshalBind, numWgs, reps, DISP); + const chainSamples = await timeDispatchPasses(device, chainPipeline, chainBind, numWgs, reps, DISP); + + const sanityOk = await readNonZero(device, chainOutBuf, 8); + + const marshalMed = median(marshalSamples); + const chainMed = median(chainSamples); + const ptsPerSample = T * s * DISP; + const marshalNsPerPt = (marshalMed * 1e6) / ptsPerSample; + const chainNsPerPt = (chainMed * 1e6) / ptsPerSample; + const combinedNsPerPt = marshalNsPerPt + chainNsPerPt; + + log( + sanityOk ? 'ok' : 'err', + `S=${s}: marshal=${marshalNsPerPt.toFixed(2)}ns/pt chain=${chainNsPerPt.toFixed(2)}ns/pt combined=${combinedNsPerPt.toFixed(2)}ns/pt density=${(density * 100).toFixed(1)}% sanity=${sanityOk ? 'OK' : 'FAIL'}`, + ); + + poolBuf.destroy(); + csrBuf.destroy(); + chunkBuf.destroy(); + chainBuf.destroy(); + chainOutBuf.destroy(); + dummy.destroy(); + marshalParams.destroy(); + chainParams.destroy(); + + return { + s, + wgi: WGI, + pairs: PAIRS, + buckets: BUCKETS, + T, + tail_points: tailPoints, + density, + disp: DISP, + marshal_median_ms: marshalMed, + marshal_min_ms: Math.min(...marshalSamples), + marshal_max_ms: Math.max(...marshalSamples), + marshal_ns_per_pt: marshalNsPerPt, + chain_median_ms: chainMed, + chain_min_ms: Math.min(...chainSamples), + chain_max_ms: Math.max(...chainSamples), + chain_ns_per_pt: chainNsPerPt, + combined_ns_per_pt: combinedNsPerPt, + marshal_samples_ms: marshalSamples, + chain_samples_ms: chainSamples, + sanity_ok: sanityOk, + }; +} + +function parseParams() { + const qp = new URLSearchParams(window.location.search); + const reps = parseInt(qp.get('reps') ?? '5', 10); + if (!Number.isFinite(reps) || reps <= 0 || reps > 50) { + throw new Error(`?reps must be in (0, 50], got ${qp.get('reps')}`); + } + const pairsStr = qp.get('pairs'); + if (pairsStr !== null) { + const v = parseInt(pairsStr, 10); + if (!Number.isFinite(v) || v <= 0 || v > (1 << 20)) { + throw new Error(`?pairs must be in (0, 2^20], got ${pairsStr}`); + } + PAIRS = v; + } + const bucketsStr = qp.get('buckets'); + if (bucketsStr !== null) { + const v = parseInt(bucketsStr, 10); + if (!Number.isFinite(v) || v <= 0 || v > (1 << 18)) { + throw new Error(`?buckets must be in (0, 2^18], got ${bucketsStr}`); + } + BUCKETS = v; + } + const wgiStr = qp.get('wgi'); + if (wgiStr !== null) { + const v = parseInt(wgiStr, 10); + if (!Number.isFinite(v) || v <= 0 || v > 1024) { + throw new Error(`?wgi must be in (0, 1024], got ${wgiStr}`); + } + WGI = v; + } + const dispStr = qp.get('disp'); + if (dispStr !== null) { + const v = parseInt(dispStr, 10); + if (!Number.isFinite(v) || v <= 0 || v > 64) { + throw new Error(`?disp must be in (0, 64], got ${dispStr}`); + } + DISP = v; + } + const sStr = qp.get('s'); + if (sStr !== null) { + const list = sStr.split(',').map(v => parseInt(v, 10)); + for (const v of list) { + if (!Number.isFinite(v) || v <= 0 || v > 256) { + throw new Error(`?s entries must be in (0, 256], got ${v}`); + } + } + S_SWEEP = list; + } + return { reps, pairs: PAIRS, buckets: BUCKETS, wgi: WGI, disp: DISP, s_sweep: S_SWEEP }; +} + +async function main() { + try { + if (!('gpu' in navigator)) { + throw new Error('navigator.gpu missing — WebGPU not available'); + } + const params = parseParams(); + benchState.params = params; + log( + 'info', + `params: reps=${params.reps} pairs=${params.pairs} buckets=${params.buckets} wgi=${params.wgi} disp=${params.disp} s=[${params.s_sweep.join(',')}]`, + ); + + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + + const p = BN254_BASE_FIELD; + const miscParams = compute_misc_params(p, 13); + const R = miscParams.r; + + const sm = new ShaderManager(4, PAIRS, BN254_CURVE_CONFIG, false); + + let seed = 0xc511; + for (const s of S_SWEEP) { + try { + const r = await runOne(device, sm, s, params.reps, R, p, seed); + benchState.results.push(r); + resultsClient.postProgress({ + kind: 'batch_done', + s, + marshal_ns_per_pt: r.marshal_ns_per_pt, + chain_ns_per_pt: r.chain_ns_per_pt, + combined_ns_per_pt: r.combined_ns_per_pt, + density: r.density, + sanity_ok: r.sanity_ok, + }); + seed += 0x10; + } catch (e) { + const msg = e instanceof Error ? e.message : String(e); + log('err', `S=${s} failed: ${msg} — STOPPING sweep at first failure`); + benchState.state = 'error'; + benchState.error = msg; + return; + } + } + + benchState.state = 'done'; + log('ok', 'all sizes done'); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main() + .catch(e => { + const msg = e instanceof Error ? e.message : String(e); + log('err', `unhandled: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + }) + .finally(() => { + postFinal().catch(() => {}); + }); diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs index 0f9c87ba5947..3da23eb9b39b 100644 --- a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs +++ b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs @@ -129,6 +129,7 @@ const pageMap = { "bench-batch-affine": "/dev/msm-webgpu/bench-batch-affine.html", "bench-fused-wg-scan": "/dev/msm-webgpu/bench-fused-wg-scan.html", "bench-ba-rev-packed-carry": "/dev/msm-webgpu/bench-ba-rev-packed-carry.html", + "bench-msm-chain": "/dev/msm-webgpu/bench-msm-chain.html", "bench-smvp-tree": "/dev/msm-webgpu/bench-smvp-tree.html", sanity: "/dev/msm-webgpu/index.html", }; diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts index 775ad5a76ed0..95bff6ad43fb 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts @@ -5,6 +5,7 @@ import { batch_affine_apply as batch_affine_apply_shader, batch_affine_apply_scatter as batch_affine_apply_scatter_shader, batch_affine_dispatch_args as batch_affine_dispatch_args_shader, + ba_marshal_chain_bench as ba_marshal_chain_bench_shader, ba_rev_packed_carry_bench as ba_rev_packed_carry_bench_shader, bench_batch_affine as bench_batch_affine_shader, batch_affine_finalize as batch_affine_finalize_shader, @@ -765,6 +766,22 @@ ${packLines.join('\n')} ); } + /** + * Marshal kernel for the bench-msm-chain pipeline: transposes a + * CSR-sorted point list into the strided SoA layout the + * ba_rev_packed_carry chain kernel consumes. Pure memory shuffle. + */ + public gen_ba_marshal_chain_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_marshal_chain_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + return mustache.render( + ba_marshal_chain_bench_shader, + { workgroup_size, s, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + /** * Standalone microbench for the recovered ba_rev_packed_carry kernel: * SoA-packed 8x u32 storage across 4 input planes (A.x, A.y, P.x, P.y), diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index 50818338cd15..c75462d6a8c5 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -1,6 +1,6 @@ // AUTO-GENERATED by scripts/inline-wgsl.mjs. DO NOT EDIT. // Run `yarn generate:wgsl` (or `node scripts/inline-wgsl.mjs`) to regenerate. -// 51 shader sources inlined. +// 52 shader sources inlined. /* eslint-disable */ @@ -1351,6 +1351,99 @@ fn main(@builtin(global_invocation_id) gid: vec3) { } `; +export const ba_marshal_chain_bench = `{{> structs }} + +// Marshal kernel for the bench-msm-chain pipeline. Transposes a CSR +// point list (sorted by bucket) into the strided SoA layout the +// ba_rev_packed_carry_bench chain kernel consumes. +// +// Input layout (point_pool): +// 2 planes (P.x, P.y), each PG=2 vec4 per element, params.y elements total. +// Plane p at point idx i: vec4 indices p*PG*N + PG*i + {0,1}. +// Convention: point_pool[0] is the "decoy" — used as the seed for every +// chunk so the chain kernel's first dx (= P_0.x - seed.x) is well- +// defined. csr_indices values are in [1, N), never 0. +// +// Output layout (chain_buf): +// 4 planes (A.x, A.y, P.x, P.y), each PG=2 vec4 per element, T*S +// elements per plane. Plane p at strided element e = t + i*T: vec4 +// indices p*PG*(T*S) + PG*e + {0,1}. +// +// Per chunk-thread t: +// - csr_start = chunk_plan[2*t + 1] (chunk_plan[2*t] = bucket_id, unused here) +// - Seed at index t (planes 0,1) := point_pool[0] (universal decoy) +// - For i in 0..S: P_i at index e = t + i*T (planes 2,3) +// := point_pool[csr_indices[csr_start + i]] +// +// The chain kernel then produces S pair-sums per chunk. The S/2 odd- +// indexed outputs (R_1, R_3, ..., R_{S-1}) are disjoint pair sums of +// {P_0..P_{S-1}}; the even outputs (R_0, R_2, ...) incorporate the +// decoy or share a P with the next odd output and are discarded by the +// subsequent reduce pass. +// +// Pure memory-shuffle kernel: no field arithmetic. Reads are coalesced +// because consecutive threads t, t+1 read adjacent csr_indices entries +// and the gathered point coords are written to adjacent vec4 slots +// (PG*e for e=t, t+1, ...). + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var csr_indices: array; +@group(0) @binding(1) var chunk_plan: array; +@group(0) @binding(2) var point_pool: array>; +@group(0) @binding(3) var chain_buf: array>; +@group(0) @binding(4) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let N = params.y; + let t = gid.x; + if (t >= T) { return; } + + let csr_start = chunk_plan[2u * t + 1u]; + + let chain_N = T * S; + let chain_plane = PG * chain_N; + let chain_ax_base = 0u * chain_plane; + let chain_ay_base = 1u * chain_plane; + let chain_px_base = 2u * chain_plane; + let chain_py_base = 3u * chain_plane; + + let pool_plane = PG * N; + let pool_px_base = 0u * pool_plane; + let pool_py_base = 1u * pool_plane; + + // Seed (A.x, A.y at index t) := point_pool[0] (decoy). + let decoy_x_off = pool_px_base + PG * 0u; + let decoy_y_off = pool_py_base + PG * 0u; + let seed_x_off = chain_ax_base + PG * t; + let seed_y_off = chain_ay_base + PG * t; + chain_buf[seed_x_off + 0u] = point_pool[decoy_x_off + 0u]; + chain_buf[seed_x_off + 1u] = point_pool[decoy_x_off + 1u]; + chain_buf[seed_y_off + 0u] = point_pool[decoy_y_off + 0u]; + chain_buf[seed_y_off + 1u] = point_pool[decoy_y_off + 1u]; + + // Gather S points from csr_indices[csr_start..csr_start+S] into the + // strided P-planes at indices e = t + i*T for i in 0..S. + for (var i = 0u; i < S; i = i + 1u) { + let pt_idx = csr_indices[csr_start + i]; + let e = t + i * T; + let pool_x_off = pool_px_base + PG * pt_idx; + let pool_y_off = pool_py_base + PG * pt_idx; + let chain_px_off = chain_px_base + PG * e; + let chain_py_off = chain_py_base + PG * e; + chain_buf[chain_px_off + 0u] = point_pool[pool_x_off + 0u]; + chain_buf[chain_px_off + 1u] = point_pool[pool_x_off + 1u]; + chain_buf[chain_py_off + 0u] = point_pool[pool_y_off + 0u]; + chain_buf[chain_py_off + 1u] = point_pool[pool_y_off + 1u]; + } + + {{{ recompile }}} +} +`; + export const ba_rev_packed_carry_bench = `{{> structs }} {{> bigint_funcs }} {{> montgomery_product_funcs }} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_chain_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_chain_bench.template.wgsl new file mode 100644 index 000000000000..538be64e43a7 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_chain_bench.template.wgsl @@ -0,0 +1,91 @@ +{{> structs }} + +// Marshal kernel for the bench-msm-chain pipeline. Transposes a CSR +// point list (sorted by bucket) into the strided SoA layout the +// ba_rev_packed_carry_bench chain kernel consumes. +// +// Input layout (point_pool): +// 2 planes (P.x, P.y), each PG=2 vec4 per element, params.y elements total. +// Plane p at point idx i: vec4 indices p*PG*N + PG*i + {0,1}. +// Convention: point_pool[0] is the "decoy" — used as the seed for every +// chunk so the chain kernel's first dx (= P_0.x - seed.x) is well- +// defined. csr_indices values are in [1, N), never 0. +// +// Output layout (chain_buf): +// 4 planes (A.x, A.y, P.x, P.y), each PG=2 vec4 per element, T*S +// elements per plane. Plane p at strided element e = t + i*T: vec4 +// indices p*PG*(T*S) + PG*e + {0,1}. +// +// Per chunk-thread t: +// - csr_start = chunk_plan[2*t + 1] (chunk_plan[2*t] = bucket_id, unused here) +// - Seed at index t (planes 0,1) := point_pool[0] (universal decoy) +// - For i in 0..S: P_i at index e = t + i*T (planes 2,3) +// := point_pool[csr_indices[csr_start + i]] +// +// The chain kernel then produces S pair-sums per chunk. The S/2 odd- +// indexed outputs (R_1, R_3, ..., R_{S-1}) are disjoint pair sums of +// {P_0..P_{S-1}}; the even outputs (R_0, R_2, ...) incorporate the +// decoy or share a P with the next odd output and are discarded by the +// subsequent reduce pass. +// +// Pure memory-shuffle kernel: no field arithmetic. Reads are coalesced +// because consecutive threads t, t+1 read adjacent csr_indices entries +// and the gathered point coords are written to adjacent vec4 slots +// (PG*e for e=t, t+1, ...). + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var csr_indices: array; +@group(0) @binding(1) var chunk_plan: array; +@group(0) @binding(2) var point_pool: array>; +@group(0) @binding(3) var chain_buf: array>; +@group(0) @binding(4) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let N = params.y; + let t = gid.x; + if (t >= T) { return; } + + let csr_start = chunk_plan[2u * t + 1u]; + + let chain_N = T * S; + let chain_plane = PG * chain_N; + let chain_ax_base = 0u * chain_plane; + let chain_ay_base = 1u * chain_plane; + let chain_px_base = 2u * chain_plane; + let chain_py_base = 3u * chain_plane; + + let pool_plane = PG * N; + let pool_px_base = 0u * pool_plane; + let pool_py_base = 1u * pool_plane; + + // Seed (A.x, A.y at index t) := point_pool[0] (decoy). + let decoy_x_off = pool_px_base + PG * 0u; + let decoy_y_off = pool_py_base + PG * 0u; + let seed_x_off = chain_ax_base + PG * t; + let seed_y_off = chain_ay_base + PG * t; + chain_buf[seed_x_off + 0u] = point_pool[decoy_x_off + 0u]; + chain_buf[seed_x_off + 1u] = point_pool[decoy_x_off + 1u]; + chain_buf[seed_y_off + 0u] = point_pool[decoy_y_off + 0u]; + chain_buf[seed_y_off + 1u] = point_pool[decoy_y_off + 1u]; + + // Gather S points from csr_indices[csr_start..csr_start+S] into the + // strided P-planes at indices e = t + i*T for i in 0..S. + for (var i = 0u; i < S; i = i + 1u) { + let pt_idx = csr_indices[csr_start + i]; + let e = t + i * T; + let pool_x_off = pool_px_base + PG * pt_idx; + let pool_y_off = pool_py_base + PG * pt_idx; + let chain_px_off = chain_px_base + PG * e; + let chain_py_off = chain_py_base + PG * e; + chain_buf[chain_px_off + 0u] = point_pool[pool_x_off + 0u]; + chain_buf[chain_px_off + 1u] = point_pool[pool_x_off + 1u]; + chain_buf[chain_py_off + 0u] = point_pool[pool_y_off + 0u]; + chain_buf[chain_py_off + 1u] = point_pool[pool_y_off + 1u]; + } + + {{{ recompile }}} +} From 33e57556206d5fe569604719ab294ac4176fdc03 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Tue, 19 May 2026 19:37:37 +0000 Subject: [PATCH 10/33] feat(bb/msm): disjoint pair-sum kernel + bench (closes 50% chain-kernel waste) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Each thread reduces 2*S input points to S disjoint pair sums R_k = P_{2k} + P_{2k+1} via the same forward-prefix-product / single fr_inv_by_a / backward-peel batched-inverse pattern as ba_rev_packed_carry, but with NO load-carry overlap. Every kernel output is a distinct disjoint pair sum suitable as input to the next pair-tree level. The chain kernel produced S overlapping sums of which only S/2 were usable (R_1, R_3, ..., R_{S-1}); the disjoint kernel produces S usable sums for the same per-thread inversion cost, reclaiming the 50% kernel-efficiency loss. Storage: SoA-packed 8x u32 per field (PG=2 vec4/elem). 2 input planes (P.x, P.y) of 2*S*T elements each; 2 output planes (R.x, R.y) of S*T each. dx values dx_k = P_{2k+1}.x - P_{2k}.x are mutually independent (no shared inputs across k), so the Montgomery batched inverse trick applies as-is — same Karatsuba+Yuval montmul, same BY-safegcd fr_inv_by_a, ONE inversion amortised across the S pair sums per thread. Files: - wgsl/cuzk/ba_pair_disjoint_bench.template.wgsl - gen_ba_pair_disjoint_bench_shader(workgroup_size, s) on ShaderManager - dev/msm-webgpu/bench-ba-pair-disjoint.{ts,html} host harness with DISP=8 dispatch amortisation matched to bench-ba-rev-packed-carry for apples-to-apples ns/op comparison - pageMap entry in run-browserstack.mjs Expected: ns/op ~25 ns matching the chain kernel's per-add cost. Since every output is now useful (vs S/2 useful in the chain), the ns/ useful-pair-sum metric should halve from ~50 to ~25 — the headline win for full MSM bucket-accumulate via pair-tree reduction. --- .../msm-webgpu/bench-ba-pair-disjoint.html | 37 ++ .../dev/msm-webgpu/bench-ba-pair-disjoint.ts | 435 ++++++++++++++++++ .../msm-webgpu/scripts/run-browserstack.mjs | 1 + .../ts/src/msm_webgpu/cuzk/shader_manager.ts | 46 ++ .../src/msm_webgpu/wgsl/_generated/shaders.ts | 142 +++++- .../cuzk/ba_pair_disjoint_bench.template.wgsl | 138 ++++++ 6 files changed, 798 insertions(+), 1 deletion(-) create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-ba-pair-disjoint.html create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-ba-pair-disjoint.ts create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_pair_disjoint_bench.template.wgsl diff --git a/barretenberg/ts/dev/msm-webgpu/bench-ba-pair-disjoint.html b/barretenberg/ts/dev/msm-webgpu/bench-ba-pair-disjoint.html new file mode 100644 index 000000000000..95dc4bee2cdb --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-ba-pair-disjoint.html @@ -0,0 +1,37 @@ + + + + + Disjoint pair-sum batch-affine bench (WebGPU) + + + +

Disjoint pair-sum batch-affine bench (WebGPU)

+

Query params: ?reps=R&pairs=N&wgi=W&disp=D&s=A,B,C

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-ba-pair-disjoint.ts b/barretenberg/ts/dev/msm-webgpu/bench-ba-pair-disjoint.ts new file mode 100644 index 000000000000..500fff0eba60 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-ba-pair-disjoint.ts @@ -0,0 +1,435 @@ +/// +// Standalone WebGPU bench for the disjoint pair-sum kernel +// (ba_pair_disjoint_bench): each thread reduces 2*S input points to +// S disjoint pair sums R_k = P_{2k} + P_{2k+1}, using one batched +// fr_inv_by_a per chunk of S. Same DISP=8 dispatch amortisation as +// bench-ba-rev-packed-carry to keep the measurement methodology +// apples-to-apples. +// +// Input: random Montgomery field elems packed into SoA layout with +// 2 planes (P.x, P.y), 2*PAIRS elements per plane. Pairs are arranged +// at strided positions e = t + i*T for i in 0..2S so the kernel's +// coalesced reads work. Adjacent pair members are guaranteed distinct +// x to avoid the lean-add div-by-zero. +// +// Output: 2 planes (R.x, R.y), PAIRS elements per plane. Reports +// ns/useful-pair-sum (every kernel output is a usable disjoint pair +// sum, vs the chain kernel where only S/2 of S outputs are usable). + +import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js'; +import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js'; +import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js'; +import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js'; +import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js'; +import { makeResultsClient } from './results_post.js'; + +const PG = 2; +const DEFAULT_PAIRS = 1 << 17; // 131072 output pair sums per dispatch +const DEFAULT_WGI = 64; +const DEFAULT_DISP = 8; +const DEFAULT_S_SWEEP: readonly number[] = [16, 32, 64]; + +let PAIRS = DEFAULT_PAIRS; +let WGI = DEFAULT_WGI; +let DISP = DEFAULT_DISP; +let S_SWEEP: readonly number[] = DEFAULT_S_SWEEP; + +function makeRng(seed: number): () => number { + let state = (seed >>> 0) || 1; + return () => { + state = (Math.imul(state, 1664525) + 1013904223) >>> 0; + return state; + }; +} + +function randomBelow(p: bigint, rng: () => number): bigint { + const bitlen = p.toString(2).length; + const byteLen = Math.ceil(bitlen / 8); + for (;;) { + let v = 0n; + for (let i = 0; i < byteLen; i++) v = (v << 8n) | BigInt(rng() & 0xff); + v &= (1n << BigInt(bitlen)) - 1n; + if (v > 0n && v < p) return v; + } +} + +function bigintToPackedU32x8(v: bigint): Uint32Array { + const w = new Uint32Array(8); + let x = v; + for (let i = 0; i < 8; i++) { + w[i] = Number(x & 0xffffffffn); + x >>= 32n; + } + return w; +} + +function median(xs: number[]): number { + if (xs.length === 0) return NaN; + const s = xs.slice().sort((a, b) => a - b); + return s[Math.floor(s.length / 2)]; +} + +// SoA-packed input buffer with 2 planes (P.x, P.y), each PG*(2*PAIRS) +// vec4. Plane p at element idx e: vec4 indices (p*PG + v)*N_in + e for +// v in 0..PG, where N_in = 2*PAIRS. Adjacent pairs (2k, 2k+1) have +// distinct x. +function buildPackedPairsSoA(pairs: number, R: bigint, p: bigint, rng: () => number): Uint32Array { + const N_in = 2 * pairs; + const buf = new Uint32Array(2 * PG * N_in * 4); + for (let k = 0; k < pairs; k++) { + let lx: bigint; + let rx: bigint; + do { + lx = (randomBelow(p, rng) * R) % p; + rx = (randomBelow(p, rng) * R) % p; + } while (lx === rx); + const ly = (randomBelow(p, rng) * R) % p; + const ry = (randomBelow(p, rng) * R) % p; + const writeElem = (planeIdx: number, e: number, val: bigint) => { + const words = bigintToPackedU32x8(val); + for (let v = 0; v < PG; v++) { + const base = ((planeIdx * PG + v) * N_in + e) * 4; + buf[base + 0] = words[4 * v + 0]; + buf[base + 1] = words[4 * v + 1]; + buf[base + 2] = words[4 * v + 2]; + buf[base + 3] = words[4 * v + 3]; + } + }; + writeElem(0, 2 * k + 0, lx); + writeElem(1, 2 * k + 0, ly); + writeElem(0, 2 * k + 1, rx); + writeElem(1, 2 * k + 1, ry); + } + return buf; +} + +interface PerSizeResult { + s: number; + wgi: number; + T: number; + num_wgs: number; + pairs: number; + disp: number; + total_ops: number; + median_ms: number; + min_ms: number; + max_ms: number; + ns_per_op: number; + samples_ms: number[]; + sanity_ok: boolean; +} + +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: { reps: number; pairs: number; wgi: number; disp: number; s_sweep: readonly number[] } | null; + results: PerSizeResult[]; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { + state: 'boot', + params: null, + results: [], + error: null, + log: [], +}; +(window as unknown as { __bench: BenchState }).__bench = benchState; + +const resultsClient = makeResultsClient({ page: 'bench-ba-pair-disjoint' }); +(window as unknown as { __runId: string }).__runId = resultsClient.runId; + +async function postFinal(): Promise { + await resultsClient.postResults({ + state: benchState.state, + params: benchState.params, + results: benchState.results, + error: benchState.error, + log: benchState.log, + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); +} + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-ba-pair-disjoint] ${msg}`); +} + +async function compile( + device: GPUDevice, + code: string, + cacheKey: string, +): Promise<{ pipeline: GPUComputePipeline; layout: GPUBindGroupLayout }> { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + const errLines: string[] = []; + for (const msg of info.messages) { + const line = `[shader ${cacheKey}] ${msg.type}: ${msg.message} (line ${msg.lineNum}, col ${msg.linePos})`; + if (msg.type === 'error') { + console.error(line); + log('err', line); + errLines.push(line); + hasError = true; + } else { + console.warn(line); + log('warn', line); + } + } + if (hasError) throw new Error(`WGSL compile failed for ${cacheKey}: ${errLines.join(' | ')}`); + const layout = device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); + const pipeline = await device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); + return { pipeline, layout }; +} + +async function readNonZero(device: GPUDevice, buf: GPUBuffer, u32Count: number): Promise { + const bytes = u32Count * 4; + const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }); + const enc = device.createCommandEncoder(); + enc.copyBufferToBuffer(buf, 0, staging, 0, bytes); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + await staging.mapAsync(GPUMapMode.READ); + const u32 = new Uint32Array(staging.getMappedRange().slice(0)); + staging.unmap(); + staging.destroy(); + for (let i = 0; i < u32.length; i++) if (u32[i] !== 0) return true; + return false; +} + +async function timeDispatch( + device: GPUDevice, + pipeline: GPUComputePipeline, + bind: GPUBindGroup, + numWgs: number, + reps: number, + passes: number, +): Promise { + { + const enc = device.createCommandEncoder(); + for (let pIdx = 0; pIdx < passes; pIdx++) { + const pass = enc.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroups(numWgs, 1, 1); + pass.end(); + } + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + } + const samples: number[] = []; + for (let r = 0; r < reps; r++) { + const enc = device.createCommandEncoder(); + for (let pIdx = 0; pIdx < passes; pIdx++) { + const pass = enc.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroups(numWgs, 1, 1); + pass.end(); + } + const t0 = performance.now(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + samples.push(performance.now() - t0); + } + return samples; +} + +async function runOne( + device: GPUDevice, + sm: ShaderManager, + s: number, + reps: number, + R: bigint, + p: bigint, + seed: number, +): Promise { + if (PAIRS % s !== 0) throw new Error(`PAIRS=${PAIRS} must be a multiple of S=${s}`); + const T = PAIRS / s; + const numWgs = Math.ceil(T / WGI); + log('info', `=== S=${s}: PAIRS=${PAIRS} T=${T} WGI=${WGI} numWgs=${numWgs} DISP=${DISP}`); + + const code = sm.gen_ba_pair_disjoint_bench_shader(WGI, s); + log('info', `compiling shader (${code.length} chars)`); + (window as unknown as Record)[`__shader_s${s}`] = code; + const cacheKey = `bench-ba-pair-disjoint-W${WGI}-S${s}`; + const { pipeline, layout } = await compile(device, code, cacheKey); + + const rng = makeRng(seed); + const inU32 = buildPackedPairsSoA(PAIRS, R, p, rng); + + const inBuf = device.createBuffer({ + size: inU32.byteLength, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST, + }); + device.queue.writeBuffer(inBuf, 0, inU32); + const dummy = device.createBuffer({ size: 16, usage: GPUBufferUsage.STORAGE }); + const outBytes = 2 * PG * PAIRS * 4 * 4; + const outBuf = device.createBuffer({ + size: outBytes, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC, + }); + const paramsBuf = device.createBuffer({ + size: 16, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + device.queue.writeBuffer(paramsBuf, 0, new Uint32Array([2 * PAIRS, T, 0, 0])); + + const bind = device.createBindGroup({ + layout, + entries: [ + { binding: 0, resource: { buffer: inBuf } }, + { binding: 1, resource: { buffer: dummy } }, + { binding: 2, resource: { buffer: outBuf } }, + { binding: 3, resource: { buffer: paramsBuf } }, + ], + }); + + const samples = await timeDispatch(device, pipeline, bind, numWgs, reps, DISP); + const sanityOk = await readNonZero(device, outBuf, 8); + const med = median(samples); + const totalOps = PAIRS * DISP; + const nsPerOp = (med * 1e6) / totalOps; + + log( + sanityOk ? 'ok' : 'err', + `S=${s}: median=${med.toFixed(3)}ms min=${Math.min(...samples).toFixed(3)}ms max=${Math.max(...samples).toFixed(3)}ms ns/op=${nsPerOp.toFixed(2)} sanity=${sanityOk ? 'OK' : 'FAIL'}`, + ); + + inBuf.destroy(); + dummy.destroy(); + outBuf.destroy(); + paramsBuf.destroy(); + + return { + s, + wgi: WGI, + T, + num_wgs: numWgs, + pairs: PAIRS, + disp: DISP, + total_ops: totalOps, + median_ms: med, + min_ms: Math.min(...samples), + max_ms: Math.max(...samples), + ns_per_op: nsPerOp, + samples_ms: samples, + sanity_ok: sanityOk, + }; +} + +function parseParams() { + const qp = new URLSearchParams(window.location.search); + const reps = parseInt(qp.get('reps') ?? '5', 10); + if (!Number.isFinite(reps) || reps <= 0 || reps > 50) throw new Error(`?reps must be in (0, 50]`); + const pairsStr = qp.get('pairs'); + if (pairsStr !== null) { + const v = parseInt(pairsStr, 10); + if (!Number.isFinite(v) || v <= 0 || v > (1 << 20)) throw new Error(`?pairs must be in (0, 2^20]`); + PAIRS = v; + } + const wgiStr = qp.get('wgi'); + if (wgiStr !== null) { + const v = parseInt(wgiStr, 10); + if (!Number.isFinite(v) || v <= 0 || v > 1024) throw new Error(`?wgi must be in (0, 1024]`); + WGI = v; + } + const dispStr = qp.get('disp'); + if (dispStr !== null) { + const v = parseInt(dispStr, 10); + if (!Number.isFinite(v) || v <= 0 || v > 64) throw new Error(`?disp must be in (0, 64]`); + DISP = v; + } + const sStr = qp.get('s'); + if (sStr !== null) { + const list = sStr.split(',').map(v => parseInt(v, 10)); + for (const v of list) { + if (!Number.isFinite(v) || v <= 0 || v > 256) throw new Error(`?s entries must be in (0, 256]`); + } + S_SWEEP = list; + } + for (const v of S_SWEEP) { + if (PAIRS % v !== 0) throw new Error(`S=${v} does not divide PAIRS=${PAIRS}`); + } + return { reps, pairs: PAIRS, wgi: WGI, disp: DISP, s_sweep: S_SWEEP }; +} + +async function main() { + try { + if (!('gpu' in navigator)) throw new Error('navigator.gpu missing — WebGPU not available'); + const params = parseParams(); + benchState.params = params; + log( + 'info', + `params: reps=${params.reps} pairs=${params.pairs} wgi=${params.wgi} disp=${params.disp} s=[${params.s_sweep.join(',')}]`, + ); + + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + + const p = BN254_BASE_FIELD; + const miscParams = compute_misc_params(p, 13); + const R = miscParams.r; + + const sm = new ShaderManager(4, PAIRS, BN254_CURVE_CONFIG, false); + + let seed = 0xd1d1; + for (const s of S_SWEEP) { + try { + const r = await runOne(device, sm, s, params.reps, R, p, seed); + benchState.results.push(r); + resultsClient.postProgress({ + kind: 'batch_done', + s, + median_ms: r.median_ms, + ns_per_op: r.ns_per_op, + sanity_ok: r.sanity_ok, + }); + seed += 0x10; + } catch (e) { + const msg = e instanceof Error ? e.message : String(e); + log('err', `S=${s} failed: ${msg} — STOPPING`); + benchState.state = 'error'; + benchState.error = msg; + return; + } + } + + benchState.state = 'done'; + log('ok', 'all sizes done'); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main() + .catch(e => { + const msg = e instanceof Error ? e.message : String(e); + log('err', `unhandled: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + }) + .finally(() => { + postFinal().catch(() => {}); + }); diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs index 3da23eb9b39b..be4d0fba17d1 100644 --- a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs +++ b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs @@ -130,6 +130,7 @@ const pageMap = { "bench-fused-wg-scan": "/dev/msm-webgpu/bench-fused-wg-scan.html", "bench-ba-rev-packed-carry": "/dev/msm-webgpu/bench-ba-rev-packed-carry.html", "bench-msm-chain": "/dev/msm-webgpu/bench-msm-chain.html", + "bench-ba-pair-disjoint": "/dev/msm-webgpu/bench-ba-pair-disjoint.html", "bench-smvp-tree": "/dev/msm-webgpu/bench-smvp-tree.html", sanity: "/dev/msm-webgpu/index.html", }; diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts index 95bff6ad43fb..0c1a093b461d 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts @@ -6,6 +6,7 @@ import { batch_affine_apply_scatter as batch_affine_apply_scatter_shader, batch_affine_dispatch_args as batch_affine_dispatch_args_shader, ba_marshal_chain_bench as ba_marshal_chain_bench_shader, + ba_pair_disjoint_bench as ba_pair_disjoint_bench_shader, ba_rev_packed_carry_bench as ba_rev_packed_carry_bench_shader, bench_batch_affine as bench_batch_affine_shader, batch_affine_finalize as batch_affine_finalize_shader, @@ -782,6 +783,51 @@ ${packLines.join('\n')} ); } + /** + * Standalone microbench for the disjoint pair-sum kernel: each + * thread reduces 2*S input points to S disjoint pair sums R_k = + * P_{2k} + P_{2k+1}, using one batched fr_inv_by_a per chunk of S. + * Reclaims the 50% kernel-efficiency loss inherent in + * ba_rev_packed_carry (which produces S overlapping pair sums of + * which only S/2 are usable for pair-tree reduction). + */ + public gen_ba_pair_disjoint_bench_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_pair_disjoint_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + const dec = this.decoupledPackUnpackWgsl(); + return mustache.render( + ba_pair_disjoint_bench_shader, + { + workgroup_size, + s, + word_size: this.word_size, + num_words: this.num_words, + n0: this.n0, + p_limbs: this.p_limbs, + r_limbs: this.r_limbs, + r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, + mask: this.mask, + two_pow_word_size: this.two_pow_word_size, + p_inv_mod_2w: this.p_inv_mod_2w, + p_inv_by_a_lo: this.p_inv_by_a_lo, + dec_unpack: dec.unpack, + dec_pack: dec.pack, + recompile: this.recompile, + }, + { + structs, + bigint_funcs, + montgomery_product_funcs: this.mont_product_src, + field_funcs, + fr_pow_funcs, + bigint_by_funcs, + by_inverse_a_funcs, + }, + ); + } + /** * Standalone microbench for the recovered ba_rev_packed_carry kernel: * SoA-packed 8x u32 storage across 4 input planes (A.x, A.y, P.x, P.y), diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index c75462d6a8c5..43910d8fb1a1 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -1,6 +1,6 @@ // AUTO-GENERATED by scripts/inline-wgsl.mjs. DO NOT EDIT. // Run `yarn generate:wgsl` (or `node scripts/inline-wgsl.mjs`) to regenerate. -// 52 shader sources inlined. +// 53 shader sources inlined. /* eslint-disable */ @@ -1444,6 +1444,146 @@ fn main(@builtin(global_invocation_id) gid: vec3) { } `; +export const ba_pair_disjoint_bench = `{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +// Disjoint pair-sum kernel — each thread reduces 2*S input points to S +// disjoint pair sums R_k = P_{2k} + P_{2k+1} (k in 0..S) using the +// same forward-prefix / single-inversion / backward-peel batched- +// inverse pattern as ba_rev_packed_carry, but with NO load-carry +// overlap. Every kernel-output is a distinct pair sum suitable as +// input to the next level of a pair-tree reduction — closes the 50% +// kernel-efficiency loss inherent in the streaming chain kernel. +// +// Storage: SoA-packed 8x u32 per field (PG=2 vec4/elem). +// Input planes (binding 0): +// plane 0 (P.x): PG * N_in vec4, N_in = 2*S*T +// plane 1 (P.y): PG * N_in vec4 +// Output planes (binding 2): +// plane 0 (R.x): PG * N_out vec4, N_out = S*T +// plane 1 (R.y): PG * N_out vec4 +// +// Thread t reads P_i = (inp[plane c at index t + i*T] : c in {0,1}) for +// i in 0..2S (strided => coalesced). Pair k pairs adjacent strided +// slots: (P_{2k}, P_{2k+1}). Output R_k is written at index t + k*T in +// plane c of outp (also strided, coalesced). +// +// dx values dx_k = P_{2k+1}.x - P_{2k}.x are all mutually independent +// (no shared inputs across k), so the standard Montgomery batched +// inverse trick applies as-is: ONE fr_inv_by_a per chunk of S. +// +// Same Karatsuba+Yuval montmul and BY-safegcd fr_inv_by_a as the +// production stack and the chain kernel. + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var inp: array>; +@group(0) @binding(1) var unused: array>; +@group(0) @binding(2) var outp: array>; +@group(0) @binding(3) var params: vec4; + +fn load_in(plane: u32, t: u32, i: u32, T: u32, N_in: u32) -> BigInt { + let plane_base = plane * PG * N_in; + let base = plane_base + PG * (t + i * T); + let q0 = inp[base + 0u]; + let q1 = inp[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_out(plane: u32, t: u32, k: u32, T: u32, N_out: u32, val: ptr) { + let plane_base = plane * PG * N_out; + let base = plane_base + PG * (t + k * T); + let w = pack_limbs_to_256(val); + outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let N_in = params.x; + let T = params.y; + let N_out = N_in / 2u; + + let t = gid.x; + if (t >= T) { return; } + + // Forward: prefix product of S independent dx values. + var pref: array; + var acc: BigInt = get_r(); + for (var k: u32 = 0u; k < S; k = k + 1u) { + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T, N_in); + var dx: BigInt = fr_sub(&p_rx, &p_lx); + if (k == 0u) { + acc = dx; + } else { + acc = montgomery_product(&acc, &dx); + } + pref[k] = acc; + } + + // One BY-safegcd inversion amortised over all S pair sums. + var inv: BigInt = fr_inv_by_a(acc); + + // Backward peel: emit S disjoint pair sums. + for (var jj: u32 = 0u; jj < S; jj = jj + 1u) { + let k = S - 1u - jj; + + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T, N_in); + var p_ly: BigInt = load_in(1u, t, 2u * k + 0u, T, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T, N_in); + var p_ry: BigInt = load_in(1u, t, 2u * k + 1u, T, N_in); + + var inv_dx: BigInt; + if (k == 0u) { + inv_dx = inv; + } else { + var pp = pref[k - 1u]; + inv_dx = montgomery_product(&inv, &pp); + } + + var lambda: BigInt = fr_sub(&p_ry, &p_ly); + lambda = montgomery_product(&lambda, &inv_dx); + var r_x: BigInt = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &p_lx); + r_x = fr_sub(&r_x, &p_rx); + var r_y: BigInt = fr_sub(&p_lx, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &p_ly); + + store_out(0u, t, k, T, N_out, &r_x); + store_out(1u, t, k, T, N_out, &r_y); + + // Advance inv to 1/pref[k-1] for the next (smaller) iteration. + if (k > 0u) { + var dx_back: BigInt = fr_sub(&p_rx, &p_lx); + inv = montgomery_product(&inv, &dx_back); + } + } + + {{{ recompile }}} +} +`; + export const ba_rev_packed_carry_bench = `{{> structs }} {{> bigint_funcs }} {{> montgomery_product_funcs }} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_pair_disjoint_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_pair_disjoint_bench.template.wgsl new file mode 100644 index 000000000000..d5a83e646f1b --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_pair_disjoint_bench.template.wgsl @@ -0,0 +1,138 @@ +{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +// Disjoint pair-sum kernel — each thread reduces 2*S input points to S +// disjoint pair sums R_k = P_{2k} + P_{2k+1} (k in 0..S) using the +// same forward-prefix / single-inversion / backward-peel batched- +// inverse pattern as ba_rev_packed_carry, but with NO load-carry +// overlap. Every kernel-output is a distinct pair sum suitable as +// input to the next level of a pair-tree reduction — closes the 50% +// kernel-efficiency loss inherent in the streaming chain kernel. +// +// Storage: SoA-packed 8x u32 per field (PG=2 vec4/elem). +// Input planes (binding 0): +// plane 0 (P.x): PG * N_in vec4, N_in = 2*S*T +// plane 1 (P.y): PG * N_in vec4 +// Output planes (binding 2): +// plane 0 (R.x): PG * N_out vec4, N_out = S*T +// plane 1 (R.y): PG * N_out vec4 +// +// Thread t reads P_i = (inp[plane c at index t + i*T] : c in {0,1}) for +// i in 0..2S (strided => coalesced). Pair k pairs adjacent strided +// slots: (P_{2k}, P_{2k+1}). Output R_k is written at index t + k*T in +// plane c of outp (also strided, coalesced). +// +// dx values dx_k = P_{2k+1}.x - P_{2k}.x are all mutually independent +// (no shared inputs across k), so the standard Montgomery batched +// inverse trick applies as-is: ONE fr_inv_by_a per chunk of S. +// +// Same Karatsuba+Yuval montmul and BY-safegcd fr_inv_by_a as the +// production stack and the chain kernel. + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var inp: array>; +@group(0) @binding(1) var unused: array>; +@group(0) @binding(2) var outp: array>; +@group(0) @binding(3) var params: vec4; + +fn load_in(plane: u32, t: u32, i: u32, T: u32, N_in: u32) -> BigInt { + let plane_base = plane * PG * N_in; + let base = plane_base + PG * (t + i * T); + let q0 = inp[base + 0u]; + let q1 = inp[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_out(plane: u32, t: u32, k: u32, T: u32, N_out: u32, val: ptr) { + let plane_base = plane * PG * N_out; + let base = plane_base + PG * (t + k * T); + let w = pack_limbs_to_256(val); + outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let N_in = params.x; + let T = params.y; + let N_out = N_in / 2u; + + let t = gid.x; + if (t >= T) { return; } + + // Forward: prefix product of S independent dx values. + var pref: array; + var acc: BigInt = get_r(); + for (var k: u32 = 0u; k < S; k = k + 1u) { + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T, N_in); + var dx: BigInt = fr_sub(&p_rx, &p_lx); + if (k == 0u) { + acc = dx; + } else { + acc = montgomery_product(&acc, &dx); + } + pref[k] = acc; + } + + // One BY-safegcd inversion amortised over all S pair sums. + var inv: BigInt = fr_inv_by_a(acc); + + // Backward peel: emit S disjoint pair sums. + for (var jj: u32 = 0u; jj < S; jj = jj + 1u) { + let k = S - 1u - jj; + + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T, N_in); + var p_ly: BigInt = load_in(1u, t, 2u * k + 0u, T, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T, N_in); + var p_ry: BigInt = load_in(1u, t, 2u * k + 1u, T, N_in); + + var inv_dx: BigInt; + if (k == 0u) { + inv_dx = inv; + } else { + var pp = pref[k - 1u]; + inv_dx = montgomery_product(&inv, &pp); + } + + var lambda: BigInt = fr_sub(&p_ry, &p_ly); + lambda = montgomery_product(&lambda, &inv_dx); + var r_x: BigInt = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &p_lx); + r_x = fr_sub(&r_x, &p_rx); + var r_y: BigInt = fr_sub(&p_lx, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &p_ly); + + store_out(0u, t, k, T, N_out, &r_x); + store_out(1u, t, k, T, N_out, &r_y); + + // Advance inv to 1/pref[k-1] for the next (smaller) iteration. + if (k > 0u) { + var dx_back: BigInt = fr_sub(&p_rx, &p_lx); + inv = montgomery_product(&inv, &dx_back); + } + } + + {{{ recompile }}} +} From dcd129cf96eeaca17dc158f039984a279f9e3d0f Mon Sep 17 00:00:00 2001 From: AztecBot Date: Tue, 19 May 2026 23:48:12 +0000 Subject: [PATCH 11/33] feat(bb/msm): complete pair-tree MSM bucket-accumulate replacement MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three new kernels that compose into a full replacement for the cuZK bucket-accumulate phase of Pippenger MSM: 1. ba_pair_disjoint_tree — tree variant of the disjoint pair-sum kernel. Writes outputs in the LAYOUT THE NEXT PAIR-TREE LEVEL EXPECTS, so multi-level reductions chain with no intervening marshal pass: out_pos(t, k) = (t >> 1) + (k + S * (t & 1)) * (T >> 1) Final-level flag (params.z) switches to a simple strided write. 2. ba_marshal_tree_l0 — CSR -> level-0 strided 2-plane input layout. Pure memory shuffle; same chunk-plan + CSR pattern as ba_marshal_chain_bench but 2-plane output (no decoy seed). 3. ba_tail_reduce — handles small buckets (count < 2*S). One thread per tail bucket, serial chain with one fr_inv_by_a per add. v1 pragmatic design with no batched inversion across threads; closes correctness for buckets that can't use the disjoint kernel. Host harness bench-msm-tree.{ts,html} drives the full pipeline: marshal-l0 -> tree-disjoint level 0 -> level 1 -> ... -> final tail Reports per-stage and combined ns/in-pt over total points. Modes: ?mode=uniform : every bucket has exactly 2*S = 32 points (clean multi-level test, no tail). ?mode=skewed : Poisson via uniform random scalar assignment, both main path and tail exercised. shader_manager.ts: - gen_ba_pair_disjoint_tree_bench_shader(workgroup_size, s) - gen_ba_marshal_tree_l0_bench_shader(workgroup_size, s) - gen_ba_tail_reduce_bench_shader(workgroup_size, s) (TAIL_CAP = 2*S - 1) run-browserstack.mjs: pageMap entry "bench-msm-tree". This is the complete kernel set for plugging the disjoint pair-sum batch-affine design into the production MSM bucket-accumulate phase. --- .../ts/dev/msm-webgpu/bench-msm-tree.html | 24 + .../ts/dev/msm-webgpu/bench-msm-tree.ts | 769 ++++++++++++++++++ .../msm-webgpu/scripts/run-browserstack.mjs | 1 + .../ts/src/msm_webgpu/cuzk/shader_manager.ts | 107 +++ .../src/msm_webgpu/wgsl/_generated/shaders.ts | 353 +++++++- .../ba_marshal_tree_l0_bench.template.wgsl | 64 ++ .../ba_pair_disjoint_tree_bench.template.wgsl | 169 ++++ .../cuzk/ba_tail_reduce_bench.template.wgsl | 112 +++ 8 files changed, 1598 insertions(+), 1 deletion(-) create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-msm-tree.html create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_tree_l0_bench.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_pair_disjoint_tree_bench.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_tail_reduce_bench.template.wgsl diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.html b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.html new file mode 100644 index 000000000000..f64e6d3ff15a --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.html @@ -0,0 +1,24 @@ + + + + + MSM bucket-accumulate multi-level pair-tree bench (WebGPU) + + + +

MSM bucket-accumulate multi-level pair-tree bench (WebGPU)

+

Query params: ?reps=R&n=N&buckets=B&s=S&wgi=W&mode=uniform|skewed&disp=D

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts new file mode 100644 index 000000000000..cf07f011a8c6 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts @@ -0,0 +1,769 @@ +/// +// bench-msm-tree — end-to-end MSM bucket-accumulate benchmark for the +// multi-level pair-tree pipeline: +// 1. marshal-l0 (CSR + chunk_plan + point_pool -> strided chain_buf) +// 2. tree-disjoint level 0 (chain_buf -> level-1 input layout, in-place via ping-pong) +// 3. tree-disjoint level 1 (continues in ping-pong) +// ... +// level L-1 (final): tree-disjoint with `final` flag -> simple strided output +// 4. tail kernel (small buckets count<2*S -> one sum each) +// +// Reports per-stage and combined ns/in-pt for the full bucket-accumulate +// over N points distributed across B buckets. +// +// Modes: +// ?mode=uniform : every bucket has exactly 2*S = 32 points (clean +// multi-level test, no tail). +// ?mode=skewed : Poisson-distributed via uniform random scalar +// assignment. Main pair-tree handles buckets with +// count >= 2*S; tail kernel handles the rest. + +import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js'; +import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js'; +import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js'; +import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js'; +import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js'; +import { makeResultsClient } from './results_post.js'; + +const PG = 2; +const DEFAULT_N = 1 << 17; // 131072 points +const DEFAULT_BUCKETS = 1 << 12; // 4096 buckets -> uniform avg 32 = 2*S +const DEFAULT_S = 16; +const DEFAULT_WGI = 64; +const DEFAULT_DISP = 4; // dispatch amortisation per timed sample +const DEFAULT_MODE = 'uniform' as const; + +let NPTS = DEFAULT_N; +let BUCKETS = DEFAULT_BUCKETS; +let S = DEFAULT_S; +let WGI = DEFAULT_WGI; +let DISP = DEFAULT_DISP; +let MODE: 'uniform' | 'skewed' = DEFAULT_MODE; + +function makeRng(seed: number): () => number { + let state = (seed >>> 0) || 1; + return () => { + state = (Math.imul(state, 1664525) + 1013904223) >>> 0; + return state; + }; +} + +function randomBelow(p: bigint, rng: () => number): bigint { + const bitlen = p.toString(2).length; + const byteLen = Math.ceil(bitlen / 8); + for (;;) { + let v = 0n; + for (let i = 0; i < byteLen; i++) v = (v << 8n) | BigInt(rng() & 0xff); + v &= (1n << BigInt(bitlen)) - 1n; + if (v > 0n && v < p) return v; + } +} + +function bigintToPackedU32x8(v: bigint): Uint32Array { + const w = new Uint32Array(8); + let x = v; + for (let i = 0; i < 8; i++) { + w[i] = Number(x & 0xffffffffn); + x >>= 32n; + } + return w; +} + +function median(xs: number[]): number { + if (xs.length === 0) return NaN; + const s = xs.slice().sort((a, b) => a - b); + return s[Math.floor(s.length / 2)]; +} + +function buildPointPool(poolSize: number, R: bigint, p: bigint, rng: () => number): Uint32Array { + const M = poolSize; + const buf = new Uint32Array(2 * PG * M * 4); + for (let e = 0; e < M; e++) { + const x = (randomBelow(p, rng) * R) % p; + const y = (randomBelow(p, rng) * R) % p; + const wx = bigintToPackedU32x8(x); + const wy = bigintToPackedU32x8(y); + for (let v = 0; v < PG; v++) { + const baseX = ((0 * PG + v) * M + e) * 4; + const baseY = ((1 * PG + v) * M + e) * 4; + buf[baseX + 0] = wx[4 * v + 0]; + buf[baseX + 1] = wx[4 * v + 1]; + buf[baseX + 2] = wx[4 * v + 2]; + buf[baseX + 3] = wx[4 * v + 3]; + buf[baseY + 0] = wy[4 * v + 0]; + buf[baseY + 1] = wy[4 * v + 1]; + buf[baseY + 2] = wy[4 * v + 2]; + buf[baseY + 3] = wy[4 * v + 3]; + } + } + return buf; +} + +interface CSR { + csrIndices: Uint32Array; // 1-based: index 0 reserved (decoy/unused) + offsets: Uint32Array; + counts: Uint32Array; +} + +function buildUniformCSR(N: number, B: number, perBucket: number): CSR { + if (N !== B * perBucket) { + throw new Error(`uniform mode requires N=${N} = B*${perBucket}=${B * perBucket}`); + } + const counts = new Uint32Array(B); + const offsets = new Uint32Array(B + 1); + const csrIndices = new Uint32Array(N); + for (let b = 0; b < B; b++) { + counts[b] = perBucket; + offsets[b + 1] = offsets[b] + perBucket; + for (let i = 0; i < perBucket; i++) { + csrIndices[offsets[b] + i] = b * perBucket + i + 1; // 1-based + } + } + return { csrIndices, offsets, counts }; +} + +function buildSkewedCSR(N: number, B: number, rng: () => number): CSR { + const bucket = new Uint32Array(N); + const counts = new Uint32Array(B); + for (let i = 0; i < N; i++) { + const b = rng() % B; + bucket[i] = b; + counts[b]++; + } + const offsets = new Uint32Array(B + 1); + for (let b = 0; b < B; b++) offsets[b + 1] = offsets[b] + counts[b]; + const cursor = new Uint32Array(B); + const csrIndices = new Uint32Array(N); + for (let i = 0; i < N; i++) { + const b = bucket[i]; + csrIndices[offsets[b] + cursor[b]++] = i + 1; // 1-based + } + return { csrIndices, offsets, counts }; +} + +// Split each bucket into (a) full 2*S chunks for the level-0 marshal + +// pair-tree pass and (b) a tail of count mod 2*S points for the tail +// kernel. Returns a chunk_plan for the main pipeline and a tail_plan +// for the tail kernel. +// +// Only buckets where count >= 2*S contribute to the main path. Each +// contributes floor(count / (2*S)) chunks of exactly 2*S points each. +// Remaining count mod 2*S points go to the tail. Buckets with +// count < 2*S go entirely to the tail. NOTE: for v1, the per-bucket +// "extra full-2S chunks already reduced via main path" is *not* +// re-folded into a single per-bucket sum; this would matter for +// correctness but for the bench we just measure dispatch wall-clock. +function buildChunkAndTailPlans( + offsets: Uint32Array, + counts: Uint32Array, + S: number, +): { chunkPlan: Uint32Array; T: number; tailPlan: Uint32Array; TT: number; mainPoints: number; tailPoints: number } { + const B = counts.length; + const blkSize = 2 * S; + let T = 0; + let TT = 0; + let mainPts = 0; + let tailPts = 0; + for (let b = 0; b < B; b++) { + const c = counts[b]; + const nMain = Math.floor(c / blkSize); + T += nMain; + mainPts += nMain * blkSize; + const remain = c - nMain * blkSize; + if (remain > 0) { + TT++; + tailPts += remain; + } + } + const chunkPlan = new Uint32Array(2 * T); + const tailPlan = new Uint32Array(3 * TT); + let t = 0; + let tt = 0; + for (let b = 0; b < B; b++) { + const c = counts[b]; + const nMain = Math.floor(c / blkSize); + for (let k = 0; k < nMain; k++) { + chunkPlan[2 * t + 0] = b; + chunkPlan[2 * t + 1] = offsets[b] + k * blkSize; + t++; + } + const remain = c - nMain * blkSize; + if (remain > 0) { + tailPlan[3 * tt + 0] = b; + tailPlan[3 * tt + 1] = offsets[b] + nMain * blkSize; + tailPlan[3 * tt + 2] = remain; + tt++; + } + } + return { chunkPlan, T, tailPlan, TT, mainPoints: mainPts, tailPoints: tailPts }; +} + +interface KernelTiming { + median_ms: number; + min_ms: number; + max_ms: number; + samples_ms: number[]; +} + +async function compile( + device: GPUDevice, + code: string, + cacheKey: string, + layout: GPUBindGroupLayout, +): Promise { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + const errLines: string[] = []; + for (const m of info.messages) { + const line = `[shader ${cacheKey}] ${m.type}: ${m.message} (line ${m.lineNum}, col ${m.linePos})`; + if (m.type === 'error') { + console.error(line); + log('err', line); + errLines.push(line); + hasError = true; + } else { + console.warn(line); + } + } + if (hasError) throw new Error(`WGSL compile failed for ${cacheKey}: ${errLines.slice(0, 4).join(' | ')}`); + return device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); +} + +function marshalLayout(device: GPUDevice): GPUBindGroupLayout { + return device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); +} + +function treeKernelLayout(device: GPUDevice): GPUBindGroupLayout { + return device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); +} + +function tailLayout(device: GPUDevice): GPUBindGroupLayout { + return device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); +} + +async function readNonZero(device: GPUDevice, buf: GPUBuffer, u32Count: number): Promise { + const bytes = u32Count * 4; + const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }); + const enc = device.createCommandEncoder(); + enc.copyBufferToBuffer(buf, 0, staging, 0, bytes); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + await staging.mapAsync(GPUMapMode.READ); + const u32 = new Uint32Array(staging.getMappedRange().slice(0)); + staging.unmap(); + staging.destroy(); + for (let i = 0; i < u32.length; i++) if (u32[i] !== 0) return true; + return false; +} + +async function timeDispatch( + device: GPUDevice, + pipeline: GPUComputePipeline, + bind: GPUBindGroup, + numWgs: number, + reps: number, + passes: number, +): Promise { + // warmup + { + const enc = device.createCommandEncoder(); + for (let p = 0; p < passes; p++) { + const pass = enc.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroups(numWgs, 1, 1); + pass.end(); + } + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + } + const samples: number[] = []; + for (let r = 0; r < reps; r++) { + const enc = device.createCommandEncoder(); + for (let p = 0; p < passes; p++) { + const pass = enc.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroups(numWgs, 1, 1); + pass.end(); + } + const t0 = performance.now(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + samples.push(performance.now() - t0); + } + return { + median_ms: median(samples), + min_ms: Math.min(...samples), + max_ms: Math.max(...samples), + samples_ms: samples, + }; +} + +interface RunResult { + s: number; + wgi: number; + disp: number; + pairs: number; + buckets: number; + mode: string; + T_main: number; + T_tail: number; + main_points: number; + tail_points: number; + levels: number; + marshal_ms: number; + level_ms: number[]; + tail_ms: number; + total_ms: number; + marshal_ns_per_inpt: number; + pair_tree_ns_per_inpt: number; + tail_ns_per_inpt: number; + combined_ns_per_inpt: number; + sanity_ok: boolean; +} + +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: { reps: number; n: number; buckets: number; s: number; wgi: number; disp: number; mode: string } | null; + results: RunResult[]; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { + state: 'boot', + params: null, + results: [], + error: null, + log: [], +}; +(window as unknown as { __bench: BenchState }).__bench = benchState; + +const resultsClient = makeResultsClient({ page: 'bench-msm-tree' }); +(window as unknown as { __runId: string }).__runId = resultsClient.runId; + +async function postFinal(): Promise { + await resultsClient.postResults({ + state: benchState.state, + params: benchState.params, + results: benchState.results, + error: benchState.error, + log: benchState.log, + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); +} + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-msm-tree] ${msg}`); +} + +async function runPipeline( + device: GPUDevice, + sm: ShaderManager, + reps: number, + R: bigint, + p: bigint, +): Promise { + log('info', `=== mode=${MODE} N=${NPTS} B=${BUCKETS} S=${S} WGI=${WGI} DISP=${DISP}`); + + const rng = makeRng(0x4711); + const poolSize = NPTS + 1; // index 0 reserved (1-based) + const poolU32 = buildPointPool(poolSize, R, p, rng); + + const csr: CSR = + MODE === 'uniform' + ? buildUniformCSR(NPTS, BUCKETS, NPTS / BUCKETS) + : buildSkewedCSR(NPTS, BUCKETS, rng); + + const offsets = csr.offsets; + const counts = csr.counts; + const { chunkPlan, T, tailPlan, TT, mainPoints, tailPoints } = buildChunkAndTailPlans(offsets, counts, S); + log( + 'info', + `plan: main T=${T} chunks (${mainPoints} pts) | tail TT=${TT} threads (${tailPoints} pts) | dropped=${NPTS - mainPoints - tailPoints}`, + ); + if (T === 0 && TT === 0) throw new Error('plan is empty'); + + // Determine number of pair-tree levels: starting at T_0 = T, halve + // each level until T == 0 or some safety cap. + const levels: number[] = []; + for (let t = T; t > 0; t = Math.floor(t / 2)) { + levels.push(t); + if (t === 1) break; + if (levels.length > 24) throw new Error('too many tree levels'); + } + log('info', `pair-tree levels: ${levels.length} (T sequence: ${levels.join(' -> ')})`); + + // Buffers. + const mkSb = (size: number, copyDst: boolean, copySrc: boolean): GPUBuffer => { + let usage = GPUBufferUsage.STORAGE; + if (copyDst) usage |= GPUBufferUsage.COPY_DST; + if (copySrc) usage |= GPUBufferUsage.COPY_SRC; + return device.createBuffer({ size, usage }); + }; + + const poolBuf = mkSb(poolU32.byteLength, true, false); + device.queue.writeBuffer(poolBuf, 0, poolU32); + let csrBuf: GPUBuffer | null = null; + let chunkBuf: GPUBuffer | null = null; + let tailPlanBuf: GPUBuffer | null = null; + if (T > 0) { + csrBuf = mkSb(csr.csrIndices.byteLength, true, false); + device.queue.writeBuffer(csrBuf, 0, csr.csrIndices); + chunkBuf = mkSb(chunkPlan.byteLength, true, false); + device.queue.writeBuffer(chunkBuf, 0, chunkPlan); + } else if (TT > 0) { + csrBuf = mkSb(csr.csrIndices.byteLength, true, false); + device.queue.writeBuffer(csrBuf, 0, csr.csrIndices); + } + if (TT > 0) { + tailPlanBuf = mkSb(tailPlan.byteLength, true, false); + device.queue.writeBuffer(tailPlanBuf, 0, tailPlan); + } + + // Ping-pong buffers for the pair-tree. Each plane needs at most + // 2*S*T_0 vec4 (level-0 input size). Output of level k is sized + // S*T_k vec4 per plane, which is half. Use two buffers of the same + // size and ping-pong. + let bufA: GPUBuffer | null = null; + let bufB: GPUBuffer | null = null; + if (T > 0) { + const planeBytes = 2 * PG * (2 * S * T) * 4 * 4; // 2 planes (P.x, P.y) * PG vec4 * (2*S*T elems) * 4 u32/vec4 * 4 B + bufA = mkSb(planeBytes, false, true); + bufB = mkSb(planeBytes, false, true); + } + + // Bucket-sums buffer (tail output): 2 planes (P.x, P.y), PG vec4 per + // bucket. Pre-zero implicitly (GPU buffers start zeroed in WebGPU). + const bucketSumsBytes = 2 * PG * BUCKETS * 4 * 4; + const bucketSumsBuf = mkSb(bucketSumsBytes, false, true); + + const paramsBytes = 16; + const marshalParams = device.createBuffer({ size: paramsBytes, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + const levelParams: GPUBuffer[] = []; + for (let i = 0; i < levels.length; i++) { + levelParams.push(device.createBuffer({ size: paramsBytes, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST })); + } + const tailParams = device.createBuffer({ size: paramsBytes, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + + // Compile + bind. + let marshalPipeline: GPUComputePipeline | null = null; + let marshalBind: GPUBindGroup | null = null; + let marshalWgs = 0; + if (T > 0 && bufA !== null && csrBuf !== null && chunkBuf !== null) { + const code = sm.gen_ba_marshal_tree_l0_bench_shader(WGI, S); + log('info', `marshal-l0 shader: ${code.length} chars`); + const mL = marshalLayout(device); + marshalPipeline = await compile(device, code, `marshal-l0-W${WGI}-S${S}`, mL); + marshalBind = device.createBindGroup({ + layout: mL, + entries: [ + { binding: 0, resource: { buffer: csrBuf } }, + { binding: 1, resource: { buffer: chunkBuf } }, + { binding: 2, resource: { buffer: poolBuf } }, + { binding: 3, resource: { buffer: bufA } }, + { binding: 4, resource: { buffer: marshalParams } }, + ], + }); + device.queue.writeBuffer(marshalParams, 0, new Uint32Array([T, poolSize, 0, 0])); + marshalWgs = Math.ceil(T / WGI); + } + + // Tree kernel: compile once, bind per-level with the appropriate + // (input, output, params) trio. + let treePipeline: GPUComputePipeline | null = null; + const treeBinds: GPUBindGroup[] = []; + const treeNumWgs: number[] = []; + if (T > 0 && bufA !== null && bufB !== null) { + const code = sm.gen_ba_pair_disjoint_tree_bench_shader(WGI, S); + log('info', `tree-disjoint shader: ${code.length} chars`); + const tL = treeKernelLayout(device); + treePipeline = await compile(device, code, `tree-disjoint-W${WGI}-S${S}`, tL); + const dummy = device.createBuffer({ size: 16, usage: GPUBufferUsage.STORAGE }); + let curIn = bufA; + let curOut = bufB; + for (let lv = 0; lv < levels.length; lv++) { + const T_lv = levels[lv]; + const N_in_lv = 2 * S * T_lv; + const isFinal = lv === levels.length - 1 ? 1 : 0; + device.queue.writeBuffer(levelParams[lv], 0, new Uint32Array([N_in_lv, T_lv, isFinal, 0])); + treeBinds.push( + device.createBindGroup({ + layout: tL, + entries: [ + { binding: 0, resource: { buffer: curIn } }, + { binding: 1, resource: { buffer: dummy } }, + { binding: 2, resource: { buffer: curOut } }, + { binding: 3, resource: { buffer: levelParams[lv] } }, + ], + }), + ); + treeNumWgs.push(Math.ceil(T_lv / WGI)); + const swap = curIn; + curIn = curOut; + curOut = swap; + } + } + + // Tail pipeline. + let tailPipeline: GPUComputePipeline | null = null; + let tailBind: GPUBindGroup | null = null; + let tailWgs = 0; + if (TT > 0 && csrBuf !== null && tailPlanBuf !== null) { + const code = sm.gen_ba_tail_reduce_bench_shader(WGI, S); + log('info', `tail shader: ${code.length} chars`); + const tL = tailLayout(device); + tailPipeline = await compile(device, code, `tail-W${WGI}-S${S}`, tL); + tailBind = device.createBindGroup({ + layout: tL, + entries: [ + { binding: 0, resource: { buffer: csrBuf } }, + { binding: 1, resource: { buffer: tailPlanBuf } }, + { binding: 2, resource: { buffer: poolBuf } }, + { binding: 3, resource: { buffer: bucketSumsBuf } }, + { binding: 4, resource: { buffer: tailParams } }, + ], + }); + device.queue.writeBuffer(tailParams, 0, new Uint32Array([TT, poolSize, BUCKETS, 0])); + tailWgs = Math.ceil(TT / WGI); + } + + // Warmup once. + if (marshalPipeline && marshalBind) { + const enc = device.createCommandEncoder(); + const pass = enc.beginComputePass(); + pass.setPipeline(marshalPipeline); + pass.setBindGroup(0, marshalBind); + pass.dispatchWorkgroups(marshalWgs, 1, 1); + pass.end(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + } + if (treePipeline) { + for (let lv = 0; lv < treeBinds.length; lv++) { + const enc = device.createCommandEncoder(); + const pass = enc.beginComputePass(); + pass.setPipeline(treePipeline); + pass.setBindGroup(0, treeBinds[lv]); + pass.dispatchWorkgroups(treeNumWgs[lv], 1, 1); + pass.end(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + } + } + if (tailPipeline && tailBind) { + const enc = device.createCommandEncoder(); + const pass = enc.beginComputePass(); + pass.setPipeline(tailPipeline); + pass.setBindGroup(0, tailBind); + pass.dispatchWorkgroups(tailWgs, 1, 1); + pass.end(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + } + + // Time each stage separately, DISP back-to-back per sample. + let marshalTiming: KernelTiming | null = null; + if (marshalPipeline && marshalBind) { + marshalTiming = await timeDispatch(device, marshalPipeline, marshalBind, marshalWgs, reps, DISP); + } + const levelTimings: KernelTiming[] = []; + if (treePipeline) { + for (let lv = 0; lv < treeBinds.length; lv++) { + const t = await timeDispatch(device, treePipeline, treeBinds[lv], treeNumWgs[lv], reps, DISP); + levelTimings.push(t); + } + } + let tailTiming: KernelTiming | null = null; + if (tailPipeline && tailBind) { + tailTiming = await timeDispatch(device, tailPipeline, tailBind, tailWgs, reps, DISP); + } + + // Sanity: at least one of the output buffers must have nonzero data. + let sanity = false; + if (treePipeline && bufA && bufB) { + const finalBuf = treeBinds.length % 2 === 0 ? bufA : bufB; + sanity = sanity || (await readNonZero(device, finalBuf, 8)); + } + if (tailPipeline) { + sanity = sanity || (await readNonZero(device, bucketSumsBuf, 8)); + } + + // Compute per-stage ns/in-pt (normalised to total points fed to that + // stage; DISP-amortised wall clock). + const totalInPts = mainPoints + tailPoints; + const marshalMed = marshalTiming?.median_ms ?? 0; + const marshalNs = (marshalMed * 1e6) / (mainPoints * DISP); + const treeTotalMed = levelTimings.reduce((acc, t) => acc + t.median_ms, 0); + const treeNs = (treeTotalMed * 1e6) / (mainPoints * DISP); + const tailMed = tailTiming?.median_ms ?? 0; + const tailNs = TT > 0 ? (tailMed * 1e6) / (tailPoints * DISP) : 0; + const combinedTotal = marshalMed + treeTotalMed + tailMed; + const combinedNs = (combinedTotal * 1e6) / (totalInPts * DISP); + + log( + sanity ? 'ok' : 'err', + `marshal=${marshalNs.toFixed(2)}ns/pt pair_tree=${treeNs.toFixed(2)}ns/pt tail=${tailNs.toFixed(2)}ns/pt | combined=${combinedNs.toFixed(2)}ns/in-pt | sanity=${sanity ? 'OK' : 'FAIL'}`, + ); + for (let lv = 0; lv < levelTimings.length; lv++) { + log('info', ` level ${lv} T=${levels[lv]}: median=${levelTimings[lv].median_ms.toFixed(3)}ms min=${levelTimings[lv].min_ms.toFixed(3)}ms`); + } + + // Cleanup + poolBuf.destroy(); + csrBuf?.destroy(); + chunkBuf?.destroy(); + tailPlanBuf?.destroy(); + bufA?.destroy(); + bufB?.destroy(); + bucketSumsBuf.destroy(); + marshalParams.destroy(); + for (const b of levelParams) b.destroy(); + tailParams.destroy(); + + return { + s: S, + wgi: WGI, + disp: DISP, + pairs: NPTS, + buckets: BUCKETS, + mode: MODE, + T_main: T, + T_tail: TT, + main_points: mainPoints, + tail_points: tailPoints, + levels: levelTimings.length, + marshal_ms: marshalMed, + level_ms: levelTimings.map(t => t.median_ms), + tail_ms: tailMed, + total_ms: combinedTotal, + marshal_ns_per_inpt: marshalNs, + pair_tree_ns_per_inpt: treeNs, + tail_ns_per_inpt: tailNs, + combined_ns_per_inpt: combinedNs, + sanity_ok: sanity, + }; +} + +function parseParams() { + const qp = new URLSearchParams(window.location.search); + const reps = parseInt(qp.get('reps') ?? '5', 10); + if (!Number.isFinite(reps) || reps <= 0 || reps > 50) throw new Error(`?reps must be in (0, 50]`); + if (qp.get('n')) { + const v = parseInt(qp.get('n')!, 10); + if (!Number.isFinite(v) || v <= 0 || v > (1 << 20)) throw new Error(`?n must be in (0, 2^20]`); + NPTS = v; + } + if (qp.get('buckets')) { + const v = parseInt(qp.get('buckets')!, 10); + if (!Number.isFinite(v) || v <= 0 || v > (1 << 18)) throw new Error(`?buckets must be in (0, 2^18]`); + BUCKETS = v; + } + if (qp.get('s')) { + const v = parseInt(qp.get('s')!, 10); + if (!Number.isFinite(v) || v <= 0 || v > 256) throw new Error(`?s must be in (0, 256]`); + S = v; + } + if (qp.get('wgi')) { + const v = parseInt(qp.get('wgi')!, 10); + if (!Number.isFinite(v) || v <= 0 || v > 1024) throw new Error(`?wgi must be in (0, 1024]`); + WGI = v; + } + if (qp.get('disp')) { + const v = parseInt(qp.get('disp')!, 10); + if (!Number.isFinite(v) || v <= 0 || v > 64) throw new Error(`?disp must be in (0, 64]`); + DISP = v; + } + if (qp.get('mode')) { + const v = qp.get('mode')!; + if (v !== 'uniform' && v !== 'skewed') throw new Error(`?mode must be uniform or skewed`); + MODE = v; + } + return { reps, n: NPTS, buckets: BUCKETS, s: S, wgi: WGI, disp: DISP, mode: MODE }; +} + +async function main() { + try { + if (!('gpu' in navigator)) throw new Error('navigator.gpu missing — WebGPU not available'); + const params = parseParams(); + benchState.params = params; + log( + 'info', + `params: reps=${params.reps} n=${params.n} buckets=${params.buckets} s=${params.s} wgi=${params.wgi} disp=${params.disp} mode=${params.mode}`, + ); + + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + + const p = BN254_BASE_FIELD; + const miscParams = compute_misc_params(p, 13); + const R = miscParams.r; + + const sm = new ShaderManager(4, NPTS, BN254_CURVE_CONFIG, false); + + const r = await runPipeline(device, sm, params.reps, R, p); + benchState.results.push(r); + resultsClient.postProgress({ + kind: 'pipeline_done', + mode: r.mode, + combined_ns_per_inpt: r.combined_ns_per_inpt, + sanity_ok: r.sanity_ok, + }); + + benchState.state = 'done'; + log('ok', 'pipeline complete'); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main() + .catch(e => { + const msg = e instanceof Error ? e.message : String(e); + log('err', `unhandled: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + }) + .finally(() => { + postFinal().catch(() => {}); + }); diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs index be4d0fba17d1..2ad069d89548 100644 --- a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs +++ b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs @@ -131,6 +131,7 @@ const pageMap = { "bench-ba-rev-packed-carry": "/dev/msm-webgpu/bench-ba-rev-packed-carry.html", "bench-msm-chain": "/dev/msm-webgpu/bench-msm-chain.html", "bench-ba-pair-disjoint": "/dev/msm-webgpu/bench-ba-pair-disjoint.html", + "bench-msm-tree": "/dev/msm-webgpu/bench-msm-tree.html", "bench-smvp-tree": "/dev/msm-webgpu/bench-smvp-tree.html", sanity: "/dev/msm-webgpu/index.html", }; diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts index 0c1a093b461d..a114752776e5 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts @@ -6,7 +6,10 @@ import { batch_affine_apply_scatter as batch_affine_apply_scatter_shader, batch_affine_dispatch_args as batch_affine_dispatch_args_shader, ba_marshal_chain_bench as ba_marshal_chain_bench_shader, + ba_marshal_tree_l0_bench as ba_marshal_tree_l0_bench_shader, ba_pair_disjoint_bench as ba_pair_disjoint_bench_shader, + ba_pair_disjoint_tree_bench as ba_pair_disjoint_tree_bench_shader, + ba_tail_reduce_bench as ba_tail_reduce_bench_shader, ba_rev_packed_carry_bench as ba_rev_packed_carry_bench_shader, bench_batch_affine as bench_batch_affine_shader, batch_affine_finalize as batch_affine_finalize_shader, @@ -783,6 +786,110 @@ ${packLines.join('\n')} ); } + /** + * Tree variant of the disjoint pair-sum kernel: writes outputs in the + * layout the next pair-tree level expects as input, so multi-level + * reductions can chain without an intervening marshal pass. Per + * thread: 2*S inputs -> S disjoint pair sums. Final-level flag (via + * params.z) switches to a simple strided write for the last pass. + */ + public gen_ba_pair_disjoint_tree_bench_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_pair_disjoint_tree_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + const dec = this.decoupledPackUnpackWgsl(); + return mustache.render( + ba_pair_disjoint_tree_bench_shader, + { + workgroup_size, + s, + word_size: this.word_size, + num_words: this.num_words, + n0: this.n0, + p_limbs: this.p_limbs, + r_limbs: this.r_limbs, + r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, + mask: this.mask, + two_pow_word_size: this.two_pow_word_size, + p_inv_mod_2w: this.p_inv_mod_2w, + p_inv_by_a_lo: this.p_inv_by_a_lo, + dec_unpack: dec.unpack, + dec_pack: dec.pack, + recompile: this.recompile, + }, + { + structs, + bigint_funcs, + montgomery_product_funcs: this.mont_product_src, + field_funcs, + fr_pow_funcs, + bigint_by_funcs, + by_inverse_a_funcs, + }, + ); + } + + /** + * Level-0 marshal kernel for the bench-msm-tree pipeline: transposes + * CSR-sorted point indices into the 2-plane strided SoA layout the + * tree-disjoint kernel reads. Pure memory shuffle. + */ + public gen_ba_marshal_tree_l0_bench_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_marshal_tree_l0_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + return mustache.render( + ba_marshal_tree_l0_bench_shader, + { workgroup_size, s, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + + /** + * Tail kernel for the bench-msm-tree pipeline: one thread per small + * bucket (count < 2*S), serial per-thread chain with one fr_inv_by_a + * per add (no batched inversion). Bounded loop up to compile-time + * TAIL_CAP = 2*S - 1. + */ + public gen_ba_tail_reduce_bench_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_tail_reduce_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + const tail_cap = 2 * s - 1; + const dec = this.decoupledPackUnpackWgsl(); + return mustache.render( + ba_tail_reduce_bench_shader, + { + workgroup_size, + tail_cap, + word_size: this.word_size, + num_words: this.num_words, + n0: this.n0, + p_limbs: this.p_limbs, + r_limbs: this.r_limbs, + r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, + mask: this.mask, + two_pow_word_size: this.two_pow_word_size, + p_inv_mod_2w: this.p_inv_mod_2w, + p_inv_by_a_lo: this.p_inv_by_a_lo, + dec_unpack: dec.unpack, + dec_pack: dec.pack, + recompile: this.recompile, + }, + { + structs, + bigint_funcs, + montgomery_product_funcs: this.mont_product_src, + field_funcs, + fr_pow_funcs, + bigint_by_funcs, + by_inverse_a_funcs, + }, + ); + } + /** * Standalone microbench for the disjoint pair-sum kernel: each * thread reduces 2*S input points to S disjoint pair sums R_k = diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index 43910d8fb1a1..c5c6b8206fb9 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -1,6 +1,6 @@ // AUTO-GENERATED by scripts/inline-wgsl.mjs. DO NOT EDIT. // Run `yarn generate:wgsl` (or `node scripts/inline-wgsl.mjs`) to regenerate. -// 53 shader sources inlined. +// 56 shader sources inlined. /* eslint-disable */ @@ -1444,6 +1444,72 @@ fn main(@builtin(global_invocation_id) gid: vec3) { } `; +export const ba_marshal_tree_l0_bench = `{{> structs }} + +// Marshal kernel for the bench-msm-tree pair-tree pipeline: transposes +// a CSR-sorted point index list into the 2-plane strided SoA layout +// the ba_pair_disjoint_tree kernel consumes at level 0. Pure memory +// shuffle, no field arithmetic. +// +// Input (point_pool): +// 2 planes (P.x, P.y), each PG=2 vec4 per element, N pool elements. +// Plane p flat vec4 indices: p*PG*N + PG*i + {0,1}. +// +// Output (chain_buf): +// 2 planes (P.x, P.y), each PG=2 vec4 per element, 2*S*T elements +// per plane. Plane p at strided element e = t + i*T: vec4 indices +// p*PG*(2*S*T) + PG*e + {0,1}. +// +// Per chunk-thread t with CSR slice [csr_start, csr_start + 2*S): +// For i in 0..2*S: +// pt_idx = csr_indices[csr_start + i] +// copy point_pool[pt_idx] (P.x, P.y) into chain_buf at e = t + i*T + +const S: u32 = {{ s }}u; +const TWOS: u32 = 2u * S; +const PG: u32 = 2u; + +@group(0) @binding(0) var csr_indices: array; +@group(0) @binding(1) var chunk_plan: array; +@group(0) @binding(2) var point_pool: array>; +@group(0) @binding(3) var chain_buf: array>; +@group(0) @binding(4) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let N = params.y; + let t = gid.x; + if (t >= T) { return; } + + let csr_start = chunk_plan[2u * t + 1u]; + + let chain_N = TWOS * T; + let chain_plane = PG * chain_N; + let chain_px_base = 0u * chain_plane; + let chain_py_base = 1u * chain_plane; + + let pool_plane = PG * N; + let pool_px_base = 0u * pool_plane; + let pool_py_base = 1u * pool_plane; + + for (var i: u32 = 0u; i < TWOS; i = i + 1u) { + let pt_idx = csr_indices[csr_start + i]; + let e = t + i * T; + let pool_x_off = pool_px_base + PG * pt_idx; + let pool_y_off = pool_py_base + PG * pt_idx; + let chain_px_off = chain_px_base + PG * e; + let chain_py_off = chain_py_base + PG * e; + chain_buf[chain_px_off + 0u] = point_pool[pool_x_off + 0u]; + chain_buf[chain_px_off + 1u] = point_pool[pool_x_off + 1u]; + chain_buf[chain_py_off + 0u] = point_pool[pool_y_off + 0u]; + chain_buf[chain_py_off + 1u] = point_pool[pool_y_off + 1u]; + } + + {{{ recompile }}} +} +`; + export const ba_pair_disjoint_bench = `{{> structs }} {{> bigint_funcs }} {{> montgomery_product_funcs }} @@ -1584,6 +1650,177 @@ fn main(@builtin(global_invocation_id) gid: vec3) { } `; +export const ba_pair_disjoint_tree_bench = `{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +// Disjoint pair-sum kernel — tree variant. Each thread reduces 2*S +// input points to S disjoint pair sums R_k = P_{2k} + P_{2k+1}, using +// one batched fr_inv_by_a per chunk of S. +// +// vs ba_pair_disjoint_bench: writes outputs in the LAYOUT THE NEXT +// PAIR-TREE LEVEL EXPECTS AS INPUT, eliminating the need for an +// intervening marshal/reshuffle dispatch between levels. +// +// Strided read at level k: thread t reads input slot i at flat +// in_pos(t, i) = t + i * T_curr (i in [0, 2*S)) +// +// Strided write that next level reads correctly: thread t writes +// output slot i at flat +// out_pos(t, i) = (t >> 1) + (i + S * (t & 1)) * (T_curr >> 1) +// +// Derivation: next level uses T_next = T_curr / 2 threads. For +// next-level thread t_n = t >> 1 to read its 2*S inputs in the right +// pair-tree order (first S from prev thread (2*t_n), next S from prev +// thread (2*t_n + 1)), the current level's output slots interleave: +// odd-t writes go into the upper-S input slots of the next level's +// thread (t >> 1), even-t into the lower-S slots. +// +// This preserves the per-bucket-pair invariant: at every level, the +// disjoint pairs (P_{2j}, P_{2j+1}) belong to the same bucket pool, +// so the lean affine formula is always combining points whose dx is +// well-defined. +// +// PARAMS: +// params.x = N_in = 2 * S * T_curr (total input elements per plane) +// params.y = T_curr +// +// LAYOUT (both input and output buffers): +// 2 planes (P.x, P.y), PG=2 vec4 per element. +// Plane p flat index for vec4 access: p * PG * N_buf + PG * e + {0,1} +// where N_buf is the elements-per-plane for that buffer. +// Input buffer's N_buf = 2 * S * T_curr (= N_in). +// Output buffer's N_buf = S * T_curr (= N_in / 2). + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var inp: array>; +@group(0) @binding(1) var unused: array>; +@group(0) @binding(2) var outp: array>; +@group(0) @binding(3) var params: vec4; + +fn load_in(plane: u32, t: u32, i: u32, T: u32, N_in: u32) -> BigInt { + let plane_base = plane * PG * N_in; + let base = plane_base + PG * (t + i * T); + let q0 = inp[base + 0u]; + let q1 = inp[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_out_tree(plane: u32, t: u32, k: u32, T_curr: u32, N_out: u32, val: ptr) { + // Tree write: out_pos(t, k) = (t >> 1) + (k + S * (t & 1)) * (T_curr >> 1) + // Lands in next-level strided read at index (t >> 1) with slot + // (k + S * (t & 1)). + let t_next = t >> 1u; + let slot_in_next = k + S * (t & 1u); + let T_next = T_curr >> 1u; + let plane_base = plane * PG * N_out; + let elem = t_next + slot_in_next * T_next; + let base = plane_base + PG * elem; + let w = pack_limbs_to_256(val); + outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn store_out_simple(plane: u32, t: u32, k: u32, T_curr: u32, N_out: u32, val: ptr) { + // Final-level simple strided write: out_pos(t, k) = t + k * T_curr. + // Used when there is no next pair-tree level (T_curr == 1 thread, or + // the host indicates this is the last reduction step). + let plane_base = plane * PG * N_out; + let elem = t + k * T_curr; + let base = plane_base + PG * elem; + let w = pack_limbs_to_256(val); + outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let N_in = params.x; + let T_curr = params.y; + let final_flag = params.z; // non-zero => use simple strided write + let N_out = N_in / 2u; + + let t = gid.x; + if (t >= T_curr) { return; } + + var pref: array; + var acc: BigInt = get_r(); + for (var k: u32 = 0u; k < S; k = k + 1u) { + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T_curr, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T_curr, N_in); + var dx: BigInt = fr_sub(&p_rx, &p_lx); + if (k == 0u) { + acc = dx; + } else { + acc = montgomery_product(&acc, &dx); + } + pref[k] = acc; + } + + var inv: BigInt = fr_inv_by_a(acc); + + for (var jj: u32 = 0u; jj < S; jj = jj + 1u) { + let k = S - 1u - jj; + + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T_curr, N_in); + var p_ly: BigInt = load_in(1u, t, 2u * k + 0u, T_curr, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T_curr, N_in); + var p_ry: BigInt = load_in(1u, t, 2u * k + 1u, T_curr, N_in); + + var inv_dx: BigInt; + if (k == 0u) { + inv_dx = inv; + } else { + var pp = pref[k - 1u]; + inv_dx = montgomery_product(&inv, &pp); + } + + var lambda: BigInt = fr_sub(&p_ry, &p_ly); + lambda = montgomery_product(&lambda, &inv_dx); + var r_x: BigInt = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &p_lx); + r_x = fr_sub(&r_x, &p_rx); + var r_y: BigInt = fr_sub(&p_lx, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &p_ly); + + if (final_flag != 0u) { + store_out_simple(0u, t, k, T_curr, N_out, &r_x); + store_out_simple(1u, t, k, T_curr, N_out, &r_y); + } else { + store_out_tree(0u, t, k, T_curr, N_out, &r_x); + store_out_tree(1u, t, k, T_curr, N_out, &r_y); + } + + if (k > 0u) { + var dx_back: BigInt = fr_sub(&p_rx, &p_lx); + inv = montgomery_product(&inv, &dx_back); + } + } + + {{{ recompile }}} +} +`; + export const ba_rev_packed_carry_bench = `{{> structs }} {{> bigint_funcs }} {{> montgomery_product_funcs }} @@ -1758,6 +1995,120 @@ fn main(@builtin(global_invocation_id) gid: vec3) { } `; +export const ba_tail_reduce_bench = `{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +// Tail kernel for the bench-msm-tree pipeline: reduces a single +// tail-sized bucket (count < 2*S) to one sum per thread. Each thread +// reads its bucket's count points sequentially from the SoA-packed +// point pool and accumulates them via direct affine adds (one +// fr_inv_by_a per step). +// +// Pragmatic v1 — no batched inversion across threads. Each step pays +// one full fr_inv_by_a (~80 mont mul equivalents). For typical +// Poisson(lambda=16) MSM workloads, tail buckets carry a minority of +// total work (~10-30%); the contribution to overall bucket-accumulate +// ns/in-pt is small enough that this simple design is acceptable for +// a v1 complete-replacement kernel set. A workgroup-scan +// batched-inversion variant is a follow-on optimisation that would +// drop tail cost to ~25 ns/add (matching the main pair-tree). +// +// Bindings: +// binding 0: csr_indices — sorted point indices, 1-based (index 0 reserved). +// binding 1: tail_plan — three u32 per tail thread: +// [bucket_id, csr_start, count]. +// binding 2: point_pool — SoA-packed pool (2 planes, PG=2 vec4/elem). +// binding 3: bucket_sums — SoA-packed output (2 planes, PG=2 vec4/bucket), +// one packed point per bucket. Pre-zeroed by host. +// binding 4: params — params.x=T (tail thread count), +// params.y=N (pool size), +// params.z=B (bucket_sums slot count). +// +// Bounded loop: the per-thread accumulate loop iterates up to compile- +// time TAIL_CAP = 2*S - 1, breaking early when i >= count. No +// data-dependent unbounded loops. + +const TAIL_CAP: u32 = {{ tail_cap }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var csr_indices: array; +@group(0) @binding(1) var tail_plan: array; +@group(0) @binding(2) var point_pool: array>; +@group(0) @binding(3) var bucket_sums: array>; +@group(0) @binding(4) var params: vec4; + +fn load_pool(plane: u32, idx: u32, N: u32) -> BigInt { + let plane_base = plane * PG * N; + let base = plane_base + PG * idx; + let q0 = point_pool[base + 0u]; + let q1 = point_pool[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_bucket(plane: u32, b: u32, B: u32, val: ptr) { + let plane_base = plane * PG * B; + let base = plane_base + PG * b; + let w = pack_limbs_to_256(val); + bucket_sums[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + bucket_sums[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let N = params.y; + let B = params.z; + + let t = gid.x; + if (t >= T) { return; } + + let bucket_id = tail_plan[3u * t + 0u]; + let csr_start = tail_plan[3u * t + 1u]; + let count = tail_plan[3u * t + 2u]; + + if (count == 0u) { return; } + + var acc_x: BigInt = load_pool(0u, csr_indices[csr_start], N); + var acc_y: BigInt = load_pool(1u, csr_indices[csr_start], N); + + for (var i: u32 = 1u; i < TAIL_CAP; i = i + 1u) { + if (i >= count) { break; } + let pt_idx = csr_indices[csr_start + i]; + var p_x: BigInt = load_pool(0u, pt_idx, N); + var p_y: BigInt = load_pool(1u, pt_idx, N); + var dx: BigInt = fr_sub(&p_x, &acc_x); + var inv_dx: BigInt = fr_inv_by_a(dx); + var dy: BigInt = fr_sub(&p_y, &acc_y); + var lambda: BigInt = montgomery_product(&dy, &inv_dx); + var lambda_sq: BigInt = montgomery_product(&lambda, &lambda); + var r_x: BigInt = fr_sub(&lambda_sq, &acc_x); + r_x = fr_sub(&r_x, &p_x); + var r_y: BigInt = fr_sub(&acc_x, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &acc_y); + acc_x = r_x; + acc_y = r_y; + } + + store_bucket(0u, bucket_id, B, &acc_x); + store_bucket(1u, bucket_id, B, &acc_y); + + {{{ recompile }}} +} +`; + export const barrett = `const W_MASK = {{ w_mask }}u; const SLACK = {{ slack }}u; diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_tree_l0_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_tree_l0_bench.template.wgsl new file mode 100644 index 000000000000..4b1539600e64 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_tree_l0_bench.template.wgsl @@ -0,0 +1,64 @@ +{{> structs }} + +// Marshal kernel for the bench-msm-tree pair-tree pipeline: transposes +// a CSR-sorted point index list into the 2-plane strided SoA layout +// the ba_pair_disjoint_tree kernel consumes at level 0. Pure memory +// shuffle, no field arithmetic. +// +// Input (point_pool): +// 2 planes (P.x, P.y), each PG=2 vec4 per element, N pool elements. +// Plane p flat vec4 indices: p*PG*N + PG*i + {0,1}. +// +// Output (chain_buf): +// 2 planes (P.x, P.y), each PG=2 vec4 per element, 2*S*T elements +// per plane. Plane p at strided element e = t + i*T: vec4 indices +// p*PG*(2*S*T) + PG*e + {0,1}. +// +// Per chunk-thread t with CSR slice [csr_start, csr_start + 2*S): +// For i in 0..2*S: +// pt_idx = csr_indices[csr_start + i] +// copy point_pool[pt_idx] (P.x, P.y) into chain_buf at e = t + i*T + +const S: u32 = {{ s }}u; +const TWOS: u32 = 2u * S; +const PG: u32 = 2u; + +@group(0) @binding(0) var csr_indices: array; +@group(0) @binding(1) var chunk_plan: array; +@group(0) @binding(2) var point_pool: array>; +@group(0) @binding(3) var chain_buf: array>; +@group(0) @binding(4) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let N = params.y; + let t = gid.x; + if (t >= T) { return; } + + let csr_start = chunk_plan[2u * t + 1u]; + + let chain_N = TWOS * T; + let chain_plane = PG * chain_N; + let chain_px_base = 0u * chain_plane; + let chain_py_base = 1u * chain_plane; + + let pool_plane = PG * N; + let pool_px_base = 0u * pool_plane; + let pool_py_base = 1u * pool_plane; + + for (var i: u32 = 0u; i < TWOS; i = i + 1u) { + let pt_idx = csr_indices[csr_start + i]; + let e = t + i * T; + let pool_x_off = pool_px_base + PG * pt_idx; + let pool_y_off = pool_py_base + PG * pt_idx; + let chain_px_off = chain_px_base + PG * e; + let chain_py_off = chain_py_base + PG * e; + chain_buf[chain_px_off + 0u] = point_pool[pool_x_off + 0u]; + chain_buf[chain_px_off + 1u] = point_pool[pool_x_off + 1u]; + chain_buf[chain_py_off + 0u] = point_pool[pool_y_off + 0u]; + chain_buf[chain_py_off + 1u] = point_pool[pool_y_off + 1u]; + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_pair_disjoint_tree_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_pair_disjoint_tree_bench.template.wgsl new file mode 100644 index 000000000000..d12b6176ede1 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_pair_disjoint_tree_bench.template.wgsl @@ -0,0 +1,169 @@ +{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +// Disjoint pair-sum kernel — tree variant. Each thread reduces 2*S +// input points to S disjoint pair sums R_k = P_{2k} + P_{2k+1}, using +// one batched fr_inv_by_a per chunk of S. +// +// vs ba_pair_disjoint_bench: writes outputs in the LAYOUT THE NEXT +// PAIR-TREE LEVEL EXPECTS AS INPUT, eliminating the need for an +// intervening marshal/reshuffle dispatch between levels. +// +// Strided read at level k: thread t reads input slot i at flat +// in_pos(t, i) = t + i * T_curr (i in [0, 2*S)) +// +// Strided write that next level reads correctly: thread t writes +// output slot i at flat +// out_pos(t, i) = (t >> 1) + (i + S * (t & 1)) * (T_curr >> 1) +// +// Derivation: next level uses T_next = T_curr / 2 threads. For +// next-level thread t_n = t >> 1 to read its 2*S inputs in the right +// pair-tree order (first S from prev thread (2*t_n), next S from prev +// thread (2*t_n + 1)), the current level's output slots interleave: +// odd-t writes go into the upper-S input slots of the next level's +// thread (t >> 1), even-t into the lower-S slots. +// +// This preserves the per-bucket-pair invariant: at every level, the +// disjoint pairs (P_{2j}, P_{2j+1}) belong to the same bucket pool, +// so the lean affine formula is always combining points whose dx is +// well-defined. +// +// PARAMS: +// params.x = N_in = 2 * S * T_curr (total input elements per plane) +// params.y = T_curr +// +// LAYOUT (both input and output buffers): +// 2 planes (P.x, P.y), PG=2 vec4 per element. +// Plane p flat index for vec4 access: p * PG * N_buf + PG * e + {0,1} +// where N_buf is the elements-per-plane for that buffer. +// Input buffer's N_buf = 2 * S * T_curr (= N_in). +// Output buffer's N_buf = S * T_curr (= N_in / 2). + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var inp: array>; +@group(0) @binding(1) var unused: array>; +@group(0) @binding(2) var outp: array>; +@group(0) @binding(3) var params: vec4; + +fn load_in(plane: u32, t: u32, i: u32, T: u32, N_in: u32) -> BigInt { + let plane_base = plane * PG * N_in; + let base = plane_base + PG * (t + i * T); + let q0 = inp[base + 0u]; + let q1 = inp[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_out_tree(plane: u32, t: u32, k: u32, T_curr: u32, N_out: u32, val: ptr) { + // Tree write: out_pos(t, k) = (t >> 1) + (k + S * (t & 1)) * (T_curr >> 1) + // Lands in next-level strided read at index (t >> 1) with slot + // (k + S * (t & 1)). + let t_next = t >> 1u; + let slot_in_next = k + S * (t & 1u); + let T_next = T_curr >> 1u; + let plane_base = plane * PG * N_out; + let elem = t_next + slot_in_next * T_next; + let base = plane_base + PG * elem; + let w = pack_limbs_to_256(val); + outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn store_out_simple(plane: u32, t: u32, k: u32, T_curr: u32, N_out: u32, val: ptr) { + // Final-level simple strided write: out_pos(t, k) = t + k * T_curr. + // Used when there is no next pair-tree level (T_curr == 1 thread, or + // the host indicates this is the last reduction step). + let plane_base = plane * PG * N_out; + let elem = t + k * T_curr; + let base = plane_base + PG * elem; + let w = pack_limbs_to_256(val); + outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let N_in = params.x; + let T_curr = params.y; + let final_flag = params.z; // non-zero => use simple strided write + let N_out = N_in / 2u; + + let t = gid.x; + if (t >= T_curr) { return; } + + var pref: array; + var acc: BigInt = get_r(); + for (var k: u32 = 0u; k < S; k = k + 1u) { + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T_curr, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T_curr, N_in); + var dx: BigInt = fr_sub(&p_rx, &p_lx); + if (k == 0u) { + acc = dx; + } else { + acc = montgomery_product(&acc, &dx); + } + pref[k] = acc; + } + + var inv: BigInt = fr_inv_by_a(acc); + + for (var jj: u32 = 0u; jj < S; jj = jj + 1u) { + let k = S - 1u - jj; + + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T_curr, N_in); + var p_ly: BigInt = load_in(1u, t, 2u * k + 0u, T_curr, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T_curr, N_in); + var p_ry: BigInt = load_in(1u, t, 2u * k + 1u, T_curr, N_in); + + var inv_dx: BigInt; + if (k == 0u) { + inv_dx = inv; + } else { + var pp = pref[k - 1u]; + inv_dx = montgomery_product(&inv, &pp); + } + + var lambda: BigInt = fr_sub(&p_ry, &p_ly); + lambda = montgomery_product(&lambda, &inv_dx); + var r_x: BigInt = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &p_lx); + r_x = fr_sub(&r_x, &p_rx); + var r_y: BigInt = fr_sub(&p_lx, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &p_ly); + + if (final_flag != 0u) { + store_out_simple(0u, t, k, T_curr, N_out, &r_x); + store_out_simple(1u, t, k, T_curr, N_out, &r_y); + } else { + store_out_tree(0u, t, k, T_curr, N_out, &r_x); + store_out_tree(1u, t, k, T_curr, N_out, &r_y); + } + + if (k > 0u) { + var dx_back: BigInt = fr_sub(&p_rx, &p_lx); + inv = montgomery_product(&inv, &dx_back); + } + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_tail_reduce_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_tail_reduce_bench.template.wgsl new file mode 100644 index 000000000000..67143057bfa3 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_tail_reduce_bench.template.wgsl @@ -0,0 +1,112 @@ +{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +// Tail kernel for the bench-msm-tree pipeline: reduces a single +// tail-sized bucket (count < 2*S) to one sum per thread. Each thread +// reads its bucket's count points sequentially from the SoA-packed +// point pool and accumulates them via direct affine adds (one +// fr_inv_by_a per step). +// +// Pragmatic v1 — no batched inversion across threads. Each step pays +// one full fr_inv_by_a (~80 mont mul equivalents). For typical +// Poisson(lambda=16) MSM workloads, tail buckets carry a minority of +// total work (~10-30%); the contribution to overall bucket-accumulate +// ns/in-pt is small enough that this simple design is acceptable for +// a v1 complete-replacement kernel set. A workgroup-scan +// batched-inversion variant is a follow-on optimisation that would +// drop tail cost to ~25 ns/add (matching the main pair-tree). +// +// Bindings: +// binding 0: csr_indices — sorted point indices, 1-based (index 0 reserved). +// binding 1: tail_plan — three u32 per tail thread: +// [bucket_id, csr_start, count]. +// binding 2: point_pool — SoA-packed pool (2 planes, PG=2 vec4/elem). +// binding 3: bucket_sums — SoA-packed output (2 planes, PG=2 vec4/bucket), +// one packed point per bucket. Pre-zeroed by host. +// binding 4: params — params.x=T (tail thread count), +// params.y=N (pool size), +// params.z=B (bucket_sums slot count). +// +// Bounded loop: the per-thread accumulate loop iterates up to compile- +// time TAIL_CAP = 2*S - 1, breaking early when i >= count. No +// data-dependent unbounded loops. + +const TAIL_CAP: u32 = {{ tail_cap }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var csr_indices: array; +@group(0) @binding(1) var tail_plan: array; +@group(0) @binding(2) var point_pool: array>; +@group(0) @binding(3) var bucket_sums: array>; +@group(0) @binding(4) var params: vec4; + +fn load_pool(plane: u32, idx: u32, N: u32) -> BigInt { + let plane_base = plane * PG * N; + let base = plane_base + PG * idx; + let q0 = point_pool[base + 0u]; + let q1 = point_pool[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_bucket(plane: u32, b: u32, B: u32, val: ptr) { + let plane_base = plane * PG * B; + let base = plane_base + PG * b; + let w = pack_limbs_to_256(val); + bucket_sums[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + bucket_sums[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let N = params.y; + let B = params.z; + + let t = gid.x; + if (t >= T) { return; } + + let bucket_id = tail_plan[3u * t + 0u]; + let csr_start = tail_plan[3u * t + 1u]; + let count = tail_plan[3u * t + 2u]; + + if (count == 0u) { return; } + + var acc_x: BigInt = load_pool(0u, csr_indices[csr_start], N); + var acc_y: BigInt = load_pool(1u, csr_indices[csr_start], N); + + for (var i: u32 = 1u; i < TAIL_CAP; i = i + 1u) { + if (i >= count) { break; } + let pt_idx = csr_indices[csr_start + i]; + var p_x: BigInt = load_pool(0u, pt_idx, N); + var p_y: BigInt = load_pool(1u, pt_idx, N); + var dx: BigInt = fr_sub(&p_x, &acc_x); + var inv_dx: BigInt = fr_inv_by_a(dx); + var dy: BigInt = fr_sub(&p_y, &acc_y); + var lambda: BigInt = montgomery_product(&dy, &inv_dx); + var lambda_sq: BigInt = montgomery_product(&lambda, &lambda); + var r_x: BigInt = fr_sub(&lambda_sq, &acc_x); + r_x = fr_sub(&r_x, &p_x); + var r_y: BigInt = fr_sub(&acc_x, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &acc_y); + acc_x = r_x; + acc_y = r_y; + } + + store_bucket(0u, bucket_id, B, &acc_x); + store_bucket(1u, bucket_id, B, &acc_y); + + {{{ recompile }}} +} From 2ab550ff7f774c94a5442b8f092e3d7f962895b6 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Tue, 19 May 2026 23:51:21 +0000 Subject: [PATCH 12/33] fix(bb/msm/bench): stop pair-tree at one sum per bucket MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously the harness halved T until T==1, which over-iterates: after T*S reaches B_main (number of buckets in main path) we have one sum per bucket, and further halvings start summing across buckets — wasted work and incorrect aggregation. Stop at T = ceil(B_main / S). --- .../ts/dev/msm-webgpu/bench-msm-tree.ts | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts index cf07f011a8c6..5dfb1c0c59d1 100644 --- a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts @@ -420,15 +420,27 @@ async function runPipeline( ); if (T === 0 && TT === 0) throw new Error('plan is empty'); - // Determine number of pair-tree levels: starting at T_0 = T, halve - // each level until T == 0 or some safety cap. + // Determine number of pair-tree levels. Each level halves T and + // every level produces T*S outputs. We stop when T*S equals the + // distinct-bucket count contributing to the main path (= one sum + // per bucket). Iterating further would start pairing across + // buckets, which is incorrect. + let bMain = 0; + for (let b = 0; b < counts.length; b++) { + if (counts[b] >= 2 * S) bMain++; + } + if (bMain === 0) throw new Error('no buckets in main path'); + const stopT = Math.max(1, Math.ceil(bMain / S)); const levels: number[] = []; - for (let t = T; t > 0; t = Math.floor(t / 2)) { + for (let t = T; t >= stopT; t = Math.floor(t / 2)) { levels.push(t); - if (t === 1) break; + if (t === stopT) break; if (levels.length > 24) throw new Error('too many tree levels'); } - log('info', `pair-tree levels: ${levels.length} (T sequence: ${levels.join(' -> ')})`); + log( + 'info', + `pair-tree levels: ${levels.length} (T sequence: ${levels.join(' -> ')}, stopT=${stopT}, bMain=${bMain})`, + ); // Buffers. const mkSb = (size: number, copyDst: boolean, copySrc: boolean): GPUBuffer => { From 5bdbbf73028f007c48773c726e650bb323152523 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Tue, 19 May 2026 23:53:42 +0000 Subject: [PATCH 13/33] fix(bb/msm/bench): use high bits of LCG to get real Poisson skew MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous skewed-mode RNG `rng() % B` used the LCG's low 12 bits, which have a period of 4096 — exactly equal to B in our default config — so every bucket received exactly N/B points and the tail path was never exercised. Mix the high 16 bits of two LCG calls into the bucket index to get the intended uniform-random scalar assignment (and thus Poisson-skewed counts). --- barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts index 5dfb1c0c59d1..34115526fada 100644 --- a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts @@ -125,8 +125,13 @@ function buildUniformCSR(N: number, B: number, perBucket: number): CSR { function buildSkewedCSR(N: number, B: number, rng: () => number): CSR { const bucket = new Uint32Array(N); const counts = new Uint32Array(B); + // LCG low bits have short periods; mix high bits of two calls into + // the bucket index to get a real uniform-random scalar assignment + // (so Poisson-skewed counts, not perfectly-even-with-cyclic-RNG). for (let i = 0; i < N; i++) { - const b = rng() % B; + const hi = (rng() >>> 16) & 0xffff; + const lo = (rng() >>> 16) & 0xffff; + const b = ((hi << 16) | lo) % B; bucket[i] = b; counts[b]++; } From b84ccd645c413635716425e64edfbac437b316fc Mon Sep 17 00:00:00 2001 From: AztecBot Date: Tue, 19 May 2026 23:55:45 +0000 Subject: [PATCH 14/33] fix(bb/msm): add get_r to tail kernel + unsigned mul for skewed RNG MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two fixes from a failed BS M2 skewed-mode run: 1. ba_tail_reduce_bench.wgsl: define get_r() locally. An included partial (fr_pow_funcs or by_inverse_a_funcs) references get_r — without a local definition the WGSL compile fails with 'unresolved call target'. 2. bench-msm-tree.ts buildSkewedCSR: (hi << 16) | lo overflows to signed i32 in JS for hi >= 0x8000, producing a negative bucket index that silently corrupts the counts array (many points 'dropped'). Use hi * 0x10000 + lo for unsigned 32-bit composition. --- barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts | 12 ++++++++---- .../ts/src/msm_webgpu/wgsl/_generated/shaders.ts | 6 ++++++ .../wgsl/cuzk/ba_tail_reduce_bench.template.wgsl | 6 ++++++ 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts index 34115526fada..f8c05dfa79c3 100644 --- a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts @@ -125,13 +125,17 @@ function buildUniformCSR(N: number, B: number, perBucket: number): CSR { function buildSkewedCSR(N: number, B: number, rng: () => number): CSR { const bucket = new Uint32Array(N); const counts = new Uint32Array(B); - // LCG low bits have short periods; mix high bits of two calls into - // the bucket index to get a real uniform-random scalar assignment - // (so Poisson-skewed counts, not perfectly-even-with-cyclic-RNG). + // LCG low bits have short periods (low 12 bits cycle every 4096 + // calls, which equals B in the default config), so direct modulo + // produces a degenerate "every bucket gets exactly N/B points" + // distribution. Compose a uniform 32-bit integer from the high + // halves of two RNG draws via unsigned arithmetic (multiplication + // to avoid the signed-i32 quirk of JS bitwise ops) and reduce. for (let i = 0; i < N; i++) { const hi = (rng() >>> 16) & 0xffff; const lo = (rng() >>> 16) & 0xffff; - const b = ((hi << 16) | lo) % B; + const v = hi * 0x10000 + lo; + const b = v % B; bucket[i] = b; counts[b]++; } diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index c5c6b8206fb9..bd7c1907ffbd 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -2065,6 +2065,12 @@ fn store_bucket(plane: u32, b: u32, B: u32, val: ptr) { bucket_sums[base + 1u] = vec4(w[4], w[5], w[6], w[7]); } +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + @compute @workgroup_size({{ workgroup_size }}) fn main(@builtin(global_invocation_id) gid: vec3) { let T = params.x; diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_tail_reduce_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_tail_reduce_bench.template.wgsl index 67143057bfa3..06e0a57e26df 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_tail_reduce_bench.template.wgsl +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_tail_reduce_bench.template.wgsl @@ -68,6 +68,12 @@ fn store_bucket(plane: u32, b: u32, B: u32, val: ptr) { bucket_sums[base + 1u] = vec4(w[4], w[5], w[6], w[7]); } +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + @compute @workgroup_size({{ workgroup_size }}) fn main(@builtin(global_invocation_id) gid: vec3) { let T = params.x; From 02c56012a5f3b77629f1b7a1dfc3e7f0589e7b47 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Wed, 20 May 2026 09:30:54 +0000 Subject: [PATCH 15/33] =?UTF-8?q?feat(bb/msm):=20bin-packed=20pair-tree=20?= =?UTF-8?q?v2=20=E2=80=94=20uniform=20perf=20on=20Poisson=20skew?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Eliminates the tail kernel by packing pairs from any combination of buckets into the same chunk-thread. Pairs respect the within-pair bucket invariant (both operands from the same bucket); different chunk-slots can come from different buckets. Odd-count carries propagate forward to the next level unchanged. Three new kernels: - ba_marshal_pairs_bench: gather operands per chunk_plan from a generic active_sums buffer (L0 = bucket-sorted point pool, L1+ = previous level's sums+carries) into chain_buf strided layout. - ba_scatter_pairs_bench: write the disjoint kernel's simple-strided outputs to active_sums_new at per-bucket destinations specified by scatter_plan. - ba_carry_copy_bench: copy odd-count carry survivors forward without modification. Host harness bench-msm-tree-v2.{ts,html} drives the iterative pipeline: per-level: build plan (host) -> marshal-pairs -> tree-disjoint -> scatter-pairs -> carry-copy Terminates when max bucket count == 1. Tail kernel is no longer in the pipeline. shader_manager: three new gen_* methods. run-browserstack.mjs: pageMap entry "bench-msm-tree-v2". Expected: uniform ~30 ns/in-pt on Poisson(λ=32) — no tail penalty. --- .../ts/dev/msm-webgpu/bench-msm-tree-v2.html | 22 + .../ts/dev/msm-webgpu/bench-msm-tree-v2.ts | 555 ++++++++++++++++++ .../msm-webgpu/scripts/run-browserstack.mjs | 1 + .../ts/src/msm_webgpu/cuzk/shader_manager.ts | 52 ++ .../src/msm_webgpu/wgsl/_generated/shaders.ts | 202 ++++++- .../cuzk/ba_carry_copy_bench.template.wgsl | 54 ++ .../cuzk/ba_marshal_pairs_bench.template.wgsl | 79 +++ .../cuzk/ba_scatter_pairs_bench.template.wgsl | 61 ++ 8 files changed, 1025 insertions(+), 1 deletion(-) create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.html create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.ts create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_carry_copy_bench.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_pairs_bench.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_scatter_pairs_bench.template.wgsl diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.html b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.html new file mode 100644 index 000000000000..bd33e10887af --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.html @@ -0,0 +1,22 @@ + + + + + Bin-packed pair-tree MSM bucket-accumulate v2 (WebGPU) + + + +

Bin-packed pair-tree MSM bucket-accumulate v2 (WebGPU)

+

Query params: ?reps=R&n=N&buckets=B&s=S&wgi=W

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.ts new file mode 100644 index 000000000000..f368411bd5f1 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.ts @@ -0,0 +1,555 @@ +/// +// bench-msm-tree-v2 — bin-packed pair-tree MSM bucket-accumulate with +// carry-forward. Eliminates the slow per-bucket tail kernel by packing +// pairs from any combination of buckets into the same chunk. +// +// For each (chunk t, slot k), the disjoint kernel sums (P_{2k}, P_{2k+1}). +// Both operands of each pair come from the SAME bucket; different +// (chunk, slot) entries can come from DIFFERENT buckets. The planner +// guarantees the within-pair bucket invariant. +// +// Per level transition: +// 1. host: per-bucket pair-count + carry, bin-pack into chunks of S. +// 2. marshal-pairs (GPU): gather operands per chunk_plan into chain_buf. +// 3. tree-disjoint (GPU, final=1): chain_buf -> simple strided output. +// 4. scatter-pairs (GPU): outputs -> active_sums_new at per-bucket positions. +// 5. carry-copy (GPU): odd-count carries -> active_sums_new (if any). +// 6. swap active_sums buffers, update counts/offsets. +// +// Terminate when max bucket count == 1. + +import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js'; +import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js'; +import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js'; +import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js'; +import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js'; +import { makeResultsClient } from './results_post.js'; + +const PG = 2; +const DEFAULT_N = 1 << 17; +const DEFAULT_BUCKETS = 1 << 12; +const DEFAULT_S = 16; +const DEFAULT_WGI = 64; + +let NPTS = DEFAULT_N; +let BUCKETS = DEFAULT_BUCKETS; +let S = DEFAULT_S; +let WGI = DEFAULT_WGI; + +function makeRng(seed: number): () => number { + let state = (seed >>> 0) || 1; + return () => { + state = (Math.imul(state, 1664525) + 1013904223) >>> 0; + return state; + }; +} + +function randomBelow(p: bigint, rng: () => number): bigint { + const bitlen = p.toString(2).length; + const byteLen = Math.ceil(bitlen / 8); + for (;;) { + let v = 0n; + for (let i = 0; i < byteLen; i++) v = (v << 8n) | BigInt(rng() & 0xff); + v &= (1n << BigInt(bitlen)) - 1n; + if (v > 0n && v < p) return v; + } +} + +function bigintToPackedU32x8(v: bigint): Uint32Array { + const w = new Uint32Array(8); + let x = v; + for (let i = 0; i < 8; i++) { + w[i] = Number(x & 0xffffffffn); + x >>= 32n; + } + return w; +} + +function makeSoABuf(device: GPUDevice, M: number, copyDst: boolean, copySrc: boolean): GPUBuffer { + const bytes = 2 * PG * M * 4 * 4; + let usage = GPUBufferUsage.STORAGE; + if (copyDst) usage |= GPUBufferUsage.COPY_DST; + if (copySrc) usage |= GPUBufferUsage.COPY_SRC; + return device.createBuffer({ size: bytes, usage }); +} + +// Build initial active_sums (Level 0). Points are assigned to random +// buckets and laid out bucket-major (bucket b's points at active_sums +// indices offsets[b] .. offsets[b]+counts[b]-1). The last 2 slots [M-2, +// M-1] hold a "pad pair" with distinct x — used to fill chunk-tail +// slots without divide-by-zero. +function buildL0ActiveSums(N: number, B: number, R: bigint, p: bigint, rng: () => number) { + const M = N + 2; + const buf = new Uint32Array(2 * PG * M * 4); + // Generate N + 2 random points. + const xWords = new Uint32Array(8 * M); + const yWords = new Uint32Array(8 * M); + for (let i = 0; i < M; i++) { + const x = (randomBelow(p, rng) * R) % p; + const y = (randomBelow(p, rng) * R) % p; + xWords.set(bigintToPackedU32x8(x), 8 * i); + yWords.set(bigintToPackedU32x8(y), 8 * i); + } + // Ensure pad pair x's differ. + if (xWords[8 * (M - 2)] === xWords[8 * (M - 1)]) { + xWords[8 * (M - 1)] ^= 1; + } + // Bucket assignment via composed hi/lo (unsigned). + const bucket = new Uint32Array(N); + const counts = new Uint32Array(B); + for (let i = 0; i < N; i++) { + const hi = (rng() >>> 16) & 0xffff; + const lo = (rng() >>> 16) & 0xffff; + const v = hi * 0x10000 + lo; + const b = v % B; + bucket[i] = b; + counts[b]++; + } + const offsets = new Uint32Array(B + 1); + for (let b = 0; b < B; b++) offsets[b + 1] = offsets[b] + counts[b]; + const cursor = new Uint32Array(B); + const writeElem = (planeIdx: number, dstIdx: number, words: Uint32Array, srcOff: number) => { + for (let v = 0; v < PG; v++) { + const base = ((planeIdx * PG + v) * M + dstIdx) * 4; + buf[base + 0] = words[srcOff + 4 * v + 0]; + buf[base + 1] = words[srcOff + 4 * v + 1]; + buf[base + 2] = words[srcOff + 4 * v + 2]; + buf[base + 3] = words[srcOff + 4 * v + 3]; + } + }; + for (let i = 0; i < N; i++) { + const b = bucket[i]; + const dst = offsets[b] + cursor[b]++; + writeElem(0, dst, xWords, 8 * i); + writeElem(1, dst, yWords, 8 * i); + } + // Pad pair at indices M-2, M-1. + writeElem(0, M - 2, xWords, 8 * (M - 2)); + writeElem(1, M - 2, yWords, 8 * (M - 2)); + writeElem(0, M - 1, xWords, 8 * (M - 1)); + writeElem(1, M - 1, yWords, 8 * (M - 1)); + return { initBuf: buf, initCounts: counts, initOffsets: offsets, M }; +} + +// Bin-pack the per-bucket pairs into chunks of S. Returns the per-chunk +// operand-index plan and the per-output destination plan, plus the +// carry plan and next-level (counts, offsets). +function buildLevelPlan( + counts: Uint32Array, + offsets: Uint32Array, + s: number, + padLIdx: number, + padRIdx: number, + discardIdx: number, +) { + const B = counts.length; + let totalPairs = 0; + let totalCarries = 0; + const newCounts = new Uint32Array(B); + for (let b = 0; b < B; b++) { + const n = counts[b]; + const p = Math.floor(n / 2); + const c = n & 1; + totalPairs += p; + totalCarries += c; + newCounts[b] = p + c; + } + const newOffsets = new Uint32Array(B + 1); + for (let b = 0; b < B; b++) newOffsets[b + 1] = newOffsets[b] + newCounts[b]; + + const numChunks = Math.max(1, Math.ceil(totalPairs / s)); + const chunkPlan = new Uint32Array(2 * s * numChunks); + const scatterPlan = new Uint32Array(s * numChunks); + const carryPlan = new Uint32Array(2 * Math.max(1, totalCarries)); + + for (let i = 0; i < numChunks * s; i++) { + chunkPlan[2 * i + 0] = padLIdx; + chunkPlan[2 * i + 1] = padRIdx; + scatterPlan[i] = discardIdx; + } + + let slot = 0; + let carryIdx = 0; + for (let b = 0; b < B; b++) { + const n = counts[b]; + const p = Math.floor(n / 2); + for (let j = 0; j < p; j++) { + chunkPlan[2 * slot + 0] = offsets[b] + 2 * j; + chunkPlan[2 * slot + 1] = offsets[b] + 2 * j + 1; + scatterPlan[slot] = newOffsets[b] + j; + slot++; + } + if (n & 1) { + carryPlan[2 * carryIdx + 0] = offsets[b] + n - 1; + carryPlan[2 * carryIdx + 1] = newOffsets[b] + p; + carryIdx++; + } + } + return { chunkPlan, scatterPlan, carryPlan, newCounts, newOffsets, numChunks, numCarries: totalCarries, totalPairs }; +} + +interface LevelTiming { + T: number; + pairs: number; + carries: number; + marshal_ms: number; + disjoint_ms: number; + scatter_ms: number; + carry_ms: number; +} + +interface RunResult { + s: number; + wgi: number; + pairs: number; + buckets: number; + levels: number; + total_pair_adds: number; + total_wall_ms: number; + level_timings: LevelTiming[]; + ns_per_inpt: number; + sanity_ok: boolean; +} + +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: { reps: number; n: number; buckets: number; s: number; wgi: number } | null; + results: RunResult[]; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { state: 'boot', params: null, results: [], error: null, log: [] }; +(window as unknown as { __bench: BenchState }).__bench = benchState; + +const resultsClient = makeResultsClient({ page: 'bench-msm-tree-v2' }); +(window as unknown as { __runId: string }).__runId = resultsClient.runId; + +async function postFinal(): Promise { + await resultsClient.postResults({ + state: benchState.state, params: benchState.params, results: benchState.results, + error: benchState.error, log: benchState.log, + userAgent: navigator.userAgent, hardwareConcurrency: navigator.hardwareConcurrency, + }); +} + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-msm-tree-v2] ${msg}`); +} + +async function compileOne(device: GPUDevice, code: string, key: string, layout: GPUBindGroupLayout): Promise { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + const errLines: string[] = []; + for (const m of info.messages) { + const line = `[shader ${key}] ${m.type}: ${m.message} (line ${m.lineNum}, col ${m.linePos})`; + if (m.type === 'error') { console.error(line); log('err', line); errLines.push(line); hasError = true; } + else { console.warn(line); } + } + if (hasError) throw new Error(`WGSL compile failed for ${key}: ${errLines.slice(0, 4).join(' | ')}`); + return device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); +} + +function ioLayout4(device: GPUDevice): GPUBindGroupLayout { + return device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); +} + +async function readNonZero(device: GPUDevice, buf: GPUBuffer, u32Count: number): Promise { + const bytes = u32Count * 4; + const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }); + const enc = device.createCommandEncoder(); + enc.copyBufferToBuffer(buf, 0, staging, 0, bytes); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + await staging.mapAsync(GPUMapMode.READ); + const u32 = new Uint32Array(staging.getMappedRange().slice(0)); + staging.unmap(); + staging.destroy(); + for (let i = 0; i < u32.length; i++) if (u32[i] !== 0) return true; + return false; +} + +async function timeOne( + device: GPUDevice, + pipeline: GPUComputePipeline, + bind: GPUBindGroup, + numWgs: number, +): Promise { + const enc = device.createCommandEncoder(); + const pass = enc.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroups(numWgs, 1, 1); + pass.end(); + const t0 = performance.now(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + return performance.now() - t0; +} + +async function runPipeline(device: GPUDevice, sm: ShaderManager, reps: number, R: bigint, p: bigint): Promise { + log('info', `=== N=${NPTS} B=${BUCKETS} S=${S} WGI=${WGI}`); + + const rng = makeRng(0x9111); + const { initBuf, initCounts, initOffsets, M } = buildL0ActiveSums(NPTS, BUCKETS, R, p, rng); + log('info', `built L0 active_sums: M=${M}`); + + // Histogram peek + let maxC0 = 0, minC0 = NPTS, c0 = 0, smallC = 0; + for (let b = 0; b < initCounts.length; b++) { + if (initCounts[b] > maxC0) maxC0 = initCounts[b]; + if (initCounts[b] < minC0) minC0 = initCounts[b]; + if (initCounts[b] === 0) c0++; + if (initCounts[b] < 32) smallC++; + } + log('info', `bucket counts: min=${minC0} max=${maxC0} zero=${c0} small(<32)=${smallC}/${initCounts.length}`); + + const padLIdx = M - 2; + const padRIdx = M - 1; + // Use a fixed discard slot: M-2 (same as padLIdx). The discarded output + // overwrites the pad pair on each level — fine because the pad pair is + // re-set by buildL0 once and never relied on for correctness; in + // subsequent levels the planner's pad selection still finds a valid + // distinct-x pair as long as M-2 and M-1 hold distinct-x data at start. + // Per level, we re-seed the pad slots by copying first two slots of + // active_sums_new... actually simpler: use NEW pad slots per level. We + // need slots that have distinct x in active_sums_new. For pure safety, + // we'll have the planner discard scatters always go to (M-2) and pad + // chunks always read from active_sums_OLD's (padLIdx, padRIdx) — which + // is still ping-pong-stable if we maintain pad slots in both buffers. + const discardIdx = M - 2; + + // Two ping-pong active_sums buffers, sized M. + const bufA = makeSoABuf(device, M, true, true); + const bufB = makeSoABuf(device, M, true, true); + device.queue.writeBuffer(bufA, 0, initBuf); + // Mirror the pad pair into bufB so it's available when we ping-pong. + // Read M-2 and M-1 from initBuf and write them into bufB at the same slots. + const padPairBytes = new Uint32Array(2 * PG * 2 * 4); + for (let pl = 0; pl < 2; pl++) { + for (let v = 0; v < PG; v++) { + const baseSrc = ((pl * PG + v) * M + (M - 2)) * 4; + const baseDst = (pl * PG + v) * 2 * 4 + 0; + padPairBytes[baseDst + 0] = initBuf[baseSrc + 0]; + padPairBytes[baseDst + 1] = initBuf[baseSrc + 1]; + padPairBytes[baseDst + 2] = initBuf[baseSrc + 2]; + padPairBytes[baseDst + 3] = initBuf[baseSrc + 3]; + const baseSrc2 = ((pl * PG + v) * M + (M - 1)) * 4; + padPairBytes[baseDst + 4] = initBuf[baseSrc2 + 0]; + padPairBytes[baseDst + 5] = initBuf[baseSrc2 + 1]; + padPairBytes[baseDst + 6] = initBuf[baseSrc2 + 2]; + padPairBytes[baseDst + 7] = initBuf[baseSrc2 + 3]; + } + } + // padPairBytes layout: same SoA but with M=2. Write into both bufA pad + // region and bufB pad region. bufA already has them via initBuf; for + // bufB write a sparse pad pair via a small upload at the pad offset. + // Simpler: write the entire initBuf into bufB too (initial state matches). + device.queue.writeBuffer(bufB, 0, initBuf); + + // Scratch buffers. + const maxL0Chunks = Math.ceil(NPTS / 2 / S) + 1; + const chainBuf = makeSoABuf(device, 2 * S * maxL0Chunks, false, false); + const tempOutBuf = makeSoABuf(device, S * maxL0Chunks, false, true); + + // Compile all 4 pipelines. + const layoutMarshal = ioLayout4(device); + const layoutDisjoint = ioLayout4(device); + const layoutScatter = ioLayout4(device); + const layoutCarry = ioLayout4(device); + const marshalPipe = await compileOne(device, sm.gen_ba_marshal_pairs_bench_shader(WGI, S), `marshal-pairs-W${WGI}-S${S}`, layoutMarshal); + const disjointPipe = await compileOne(device, sm.gen_ba_pair_disjoint_tree_bench_shader(WGI, S), `disjoint-W${WGI}-S${S}`, layoutDisjoint); + const scatterPipe = await compileOne(device, sm.gen_ba_scatter_pairs_bench_shader(WGI, S), `scatter-pairs-W${WGI}-S${S}`, layoutScatter); + const carryPipe = await compileOne(device, sm.gen_ba_carry_copy_bench_shader(WGI), `carry-W${WGI}`, layoutCarry); + log('info', '4 pipelines compiled'); + + // Iterate. + let counts = initCounts; + let offsets = initOffsets; + let curIn: GPUBuffer = bufA; + let curOut: GPUBuffer = bufB; + let totalPairAdds = 0; + let levelIdx = 0; + const levelTimings: LevelTiming[] = []; + const dummy = device.createBuffer({ size: 16, usage: GPUBufferUsage.STORAGE }); + + const startTime = performance.now(); + + for (;;) { + let maxCount = 0; + for (let b = 0; b < counts.length; b++) if (counts[b] > maxCount) maxCount = counts[b]; + if (maxCount <= 1) break; + if (levelIdx > 24) throw new Error('exceeded safety level cap'); + + const plan = buildLevelPlan(counts, offsets, S, padLIdx, padRIdx, discardIdx); + totalPairAdds += plan.totalPairs; + const T = plan.numChunks; + const numWgs = Math.ceil(T / WGI); + log('info', `L${levelIdx}: T=${T} pairs=${plan.totalPairs} carries=${plan.numCarries} maxCount=${maxCount}`); + + const chunkPlanBuf = device.createBuffer({ size: plan.chunkPlan.byteLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(chunkPlanBuf, 0, plan.chunkPlan); + const scatterPlanBuf = device.createBuffer({ size: plan.scatterPlan.byteLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(scatterPlanBuf, 0, plan.scatterPlan); + const carryPlanBuf = device.createBuffer({ size: plan.carryPlan.byteLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(carryPlanBuf, 0, plan.carryPlan); + + const marshalParams = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(marshalParams, 0, new Uint32Array([T, M, 0, 0])); + const disjointParams = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(disjointParams, 0, new Uint32Array([2 * S * T, T, 1, 0])); // final_flag=1 + const scatterParams = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(scatterParams, 0, new Uint32Array([T, M, 0, 0])); + const carryParams = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + if (plan.numCarries > 0) { + device.queue.writeBuffer(carryParams, 0, new Uint32Array([plan.numCarries, M, M, 0])); + } + + const marshalBind = device.createBindGroup({ + layout: layoutMarshal, entries: [ + { binding: 0, resource: { buffer: chunkPlanBuf } }, + { binding: 1, resource: { buffer: curIn } }, + { binding: 2, resource: { buffer: chainBuf } }, + { binding: 3, resource: { buffer: marshalParams } }, + ], + }); + const disjointBind = device.createBindGroup({ + layout: layoutDisjoint, entries: [ + { binding: 0, resource: { buffer: chainBuf } }, + { binding: 1, resource: { buffer: dummy } }, + { binding: 2, resource: { buffer: tempOutBuf } }, + { binding: 3, resource: { buffer: disjointParams } }, + ], + }); + const scatterBind = device.createBindGroup({ + layout: layoutScatter, entries: [ + { binding: 0, resource: { buffer: scatterPlanBuf } }, + { binding: 1, resource: { buffer: tempOutBuf } }, + { binding: 2, resource: { buffer: curOut } }, + { binding: 3, resource: { buffer: scatterParams } }, + ], + }); + let carryBind: GPUBindGroup | null = null; + if (plan.numCarries > 0) { + carryBind = device.createBindGroup({ + layout: layoutCarry, entries: [ + { binding: 0, resource: { buffer: carryPlanBuf } }, + { binding: 1, resource: { buffer: curIn } }, + { binding: 2, resource: { buffer: curOut } }, + { binding: 3, resource: { buffer: carryParams } }, + ], + }); + } + + // Warmup (untimed, optional). + // Timed sequential dispatch — each kernel awaited so we get per-kernel ms. + const marshalMs = await timeOne(device, marshalPipe, marshalBind, numWgs); + const disjointMs = await timeOne(device, disjointPipe, disjointBind, numWgs); + const scatterMs = await timeOne(device, scatterPipe, scatterBind, numWgs); + let carryMs = 0; + if (plan.numCarries > 0 && carryBind) { + const carryWgs = Math.ceil(plan.numCarries / WGI); + carryMs = await timeOne(device, carryPipe, carryBind, carryWgs); + } + + levelTimings.push({ + T, pairs: plan.totalPairs, carries: plan.numCarries, + marshal_ms: marshalMs, disjoint_ms: disjointMs, scatter_ms: scatterMs, carry_ms: carryMs, + }); + log('info', ` L${levelIdx} ms: marshal=${marshalMs.toFixed(2)} disjoint=${disjointMs.toFixed(2)} scatter=${scatterMs.toFixed(2)} carry=${carryMs.toFixed(2)}`); + + // Cleanup level-local buffers. + chunkPlanBuf.destroy(); + scatterPlanBuf.destroy(); + carryPlanBuf.destroy(); + marshalParams.destroy(); + disjointParams.destroy(); + scatterParams.destroy(); + carryParams.destroy(); + + counts = plan.newCounts; + offsets = plan.newOffsets; + [curIn, curOut] = [curOut, curIn]; + levelIdx++; + } + + const wall = performance.now() - startTime; + const sanity = await readNonZero(device, curIn, 8); + + bufA.destroy(); + bufB.destroy(); + chainBuf.destroy(); + tempOutBuf.destroy(); + dummy.destroy(); + + const nsPerInpt = (wall * 1e6) / NPTS; + log( + sanity ? 'ok' : 'err', + `pipeline: ${levelIdx} levels, ${totalPairAdds} pair-adds, total_wall=${wall.toFixed(2)}ms, ns/in-pt=${nsPerInpt.toFixed(2)}, sanity=${sanity ? 'OK' : 'FAIL'}`, + ); + + return { + s: S, wgi: WGI, pairs: NPTS, buckets: BUCKETS, levels: levelIdx, + total_pair_adds: totalPairAdds, total_wall_ms: wall, level_timings: levelTimings, + ns_per_inpt: nsPerInpt, sanity_ok: sanity, + }; +} + +function parseParams() { + const qp = new URLSearchParams(window.location.search); + const reps = parseInt(qp.get('reps') ?? '3', 10); + if (!Number.isFinite(reps) || reps <= 0 || reps > 50) throw new Error(`?reps must be in (0, 50]`); + if (qp.get('n')) NPTS = parseInt(qp.get('n')!, 10); + if (qp.get('buckets')) BUCKETS = parseInt(qp.get('buckets')!, 10); + if (qp.get('s')) S = parseInt(qp.get('s')!, 10); + if (qp.get('wgi')) WGI = parseInt(qp.get('wgi')!, 10); + return { reps, n: NPTS, buckets: BUCKETS, s: S, wgi: WGI }; +} + +async function main() { + try { + if (!('gpu' in navigator)) throw new Error('navigator.gpu missing'); + const params = parseParams(); + benchState.params = params; + log('info', `params: reps=${params.reps} n=${params.n} buckets=${params.buckets} s=${params.s} wgi=${params.wgi}`); + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + const p = BN254_BASE_FIELD; + const miscParams = compute_misc_params(p, 13); + const R = miscParams.r; + const sm = new ShaderManager(4, NPTS, BN254_CURVE_CONFIG, false); + const r = await runPipeline(device, sm, params.reps, R, p); + benchState.results.push(r); + resultsClient.postProgress({ kind: 'pipeline_done', ns_per_inpt: r.ns_per_inpt, sanity_ok: r.sanity_ok }); + benchState.state = 'done'; + log('ok', 'done'); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main() + .catch(e => { const msg = e instanceof Error ? e.message : String(e); log('err', `unhandled: ${msg}`); benchState.state = 'error'; benchState.error = msg; }) + .finally(() => { postFinal().catch(() => {}); }); diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs index 2ad069d89548..6b0b0415e855 100644 --- a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs +++ b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs @@ -132,6 +132,7 @@ const pageMap = { "bench-msm-chain": "/dev/msm-webgpu/bench-msm-chain.html", "bench-ba-pair-disjoint": "/dev/msm-webgpu/bench-ba-pair-disjoint.html", "bench-msm-tree": "/dev/msm-webgpu/bench-msm-tree.html", + "bench-msm-tree-v2": "/dev/msm-webgpu/bench-msm-tree-v2.html", "bench-smvp-tree": "/dev/msm-webgpu/bench-smvp-tree.html", sanity: "/dev/msm-webgpu/index.html", }; diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts index a114752776e5..922af1919034 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts @@ -5,10 +5,13 @@ import { batch_affine_apply as batch_affine_apply_shader, batch_affine_apply_scatter as batch_affine_apply_scatter_shader, batch_affine_dispatch_args as batch_affine_dispatch_args_shader, + ba_carry_copy_bench as ba_carry_copy_bench_shader, ba_marshal_chain_bench as ba_marshal_chain_bench_shader, + ba_marshal_pairs_bench as ba_marshal_pairs_bench_shader, ba_marshal_tree_l0_bench as ba_marshal_tree_l0_bench_shader, ba_pair_disjoint_bench as ba_pair_disjoint_bench_shader, ba_pair_disjoint_tree_bench as ba_pair_disjoint_tree_bench_shader, + ba_scatter_pairs_bench as ba_scatter_pairs_bench_shader, ba_tail_reduce_bench as ba_tail_reduce_bench_shader, ba_rev_packed_carry_bench as ba_rev_packed_carry_bench_shader, bench_batch_affine as bench_batch_affine_shader, @@ -786,6 +789,55 @@ ${packLines.join('\n')} ); } + /** + * Bin-packed pair-tree: marshal kernel that gathers operands from a + * generic active_sums buffer per chunk_plan. Works at any level + * (L0 active_sums = bucket-sorted point pool, L1+ = previous + * level's pair-sums + carries). + */ + public gen_ba_marshal_pairs_bench_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_marshal_pairs_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + return mustache.render( + ba_marshal_pairs_bench_shader, + { workgroup_size, s, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + + /** + * Bin-packed pair-tree: scatter kernel that places the disjoint + * kernel's strided outputs at per-bucket destinations in + * active_sums_new per scatter_plan. + */ + public gen_ba_scatter_pairs_bench_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_scatter_pairs_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + return mustache.render( + ba_scatter_pairs_bench_shader, + { workgroup_size, s, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + + /** + * Bin-packed pair-tree: carry-copy kernel. Propagates the odd-count + * carry element forward to the next level without modification. + * Pure memory shuffle. + */ + public gen_ba_carry_copy_bench_shader(workgroup_size: number): string { + if (workgroup_size <= 0 || !Number.isInteger(workgroup_size)) { + throw new Error(`gen_ba_carry_copy_bench_shader: workgroup_size (${workgroup_size}) must be a positive integer`); + } + return mustache.render( + ba_carry_copy_bench_shader, + { workgroup_size, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + /** * Tree variant of the disjoint pair-sum kernel: writes outputs in the * layout the next pair-tree level expects as input, so multi-level diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index bd7c1907ffbd..1840bfd8d7b3 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -1,6 +1,6 @@ // AUTO-GENERATED by scripts/inline-wgsl.mjs. DO NOT EDIT. // Run `yarn generate:wgsl` (or `node scripts/inline-wgsl.mjs`) to regenerate. -// 56 shader sources inlined. +// 59 shader sources inlined. /* eslint-disable */ @@ -1351,6 +1351,62 @@ fn main(@builtin(global_invocation_id) gid: vec3) { } `; +export const ba_carry_copy_bench = `{{> structs }} + +// Carry-copy kernel for the bin-packed pair-tree MSM bucket-accumulate. +// +// For each carry slot t, copies one packed (x, y) point from +// active_sums_old[carry_plan[2*t + 0]] to +// active_sums_new[carry_plan[2*t + 1]]. +// +// Used when a bucket has an odd active count at the current level: +// floor(N_b / 2) elements get paired and produce floor(N_b / 2) sums +// in the next level, plus the (N_b mod 2 == 1) carry element propagates +// forward unchanged. +// +// Pure memory shuffle, no field arithmetic. +// +// params.x = T (number of carry-copies / threads) +// params.y = M_old (active_sums_old size, vec4-stride scaling) +// params.z = M_new (active_sums_new size, vec4-stride scaling) + +const PG: u32 = 2u; + +@group(0) @binding(0) var carry_plan: array; +@group(0) @binding(1) var active_sums_old: array>; +@group(0) @binding(2) var active_sums_new: array>; +@group(0) @binding(3) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let M_old = params.y; + let M_new = params.z; + let t = gid.x; + if (t >= T) { return; } + + let src_idx = carry_plan[2u * t + 0u]; + let dst_idx = carry_plan[2u * t + 1u]; + + let old_plane_x = 0u * PG * M_old; + let old_plane_y = 1u * PG * M_old; + let new_plane_x = 0u * PG * M_new; + let new_plane_y = 1u * PG * M_new; + + let src_x = old_plane_x + PG * src_idx; + let src_y = old_plane_y + PG * src_idx; + let dst_x = new_plane_x + PG * dst_idx; + let dst_y = new_plane_y + PG * dst_idx; + + active_sums_new[dst_x + 0u] = active_sums_old[src_x + 0u]; + active_sums_new[dst_x + 1u] = active_sums_old[src_x + 1u]; + active_sums_new[dst_y + 0u] = active_sums_old[src_y + 0u]; + active_sums_new[dst_y + 1u] = active_sums_old[src_y + 1u]; + + {{{ recompile }}} +} +`; + export const ba_marshal_chain_bench = `{{> structs }} // Marshal kernel for the bench-msm-chain pipeline. Transposes a CSR @@ -1444,6 +1500,87 @@ fn main(@builtin(global_invocation_id) gid: vec3) { } `; +export const ba_marshal_pairs_bench = `{{> structs }} + +// Marshal kernel for the bin-packed pair-tree MSM bucket-accumulate. +// +// Reads (idx_l, idx_r) operand indices per pair from chunk_plan, +// fetches the corresponding packed 8x u32 points from an active_sums +// buffer (2-plane SoA), and writes them into the disjoint kernel's +// strided input layout. +// +// Used both at level 0 (active_sums = bucket-sorted point pool) and +// at levels 1+ (active_sums = previous level's pair-sum + carry +// outputs). The kernel is bucket-agnostic; the planner has packed +// each chunk's S pairs from whatever buckets fit, and chunk_plan +// encodes the operand source indices. +// +// chunk_plan layout: 2 * S u32 per chunk +// chunk_plan[2 * (t * S + k) + 0] = idx_left (active_sums index) +// chunk_plan[2 * (t * S + k) + 1] = idx_right (active_sums index) +// +// active_sums layout: 2 planes (P.x, P.y), PG=2 vec4 per element, +// M_in elements per plane (params.y). +// +// chain_buf layout: 2 planes (P.x, P.y), PG=2 vec4 per element, +// 2 * S * T elements per plane. Slot (t, 2k+0) holds left, slot +// (t, 2k+1) holds right at the disjoint kernel's strided positions +// e = t + i * T for i = 2k, 2k+1. + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var chunk_plan: array; +@group(0) @binding(1) var active_sums: array>; +@group(0) @binding(2) var chain_buf: array>; +@group(0) @binding(3) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let M_in = params.y; + let t = gid.x; + if (t >= T) { return; } + + let chain_N = 2u * S * T; + let chain_plane_x = 0u * PG * chain_N; + let chain_plane_y = 1u * PG * chain_N; + + let active_plane_x = 0u * PG * M_in; + let active_plane_y = 1u * PG * M_in; + + let chunk_base = 2u * S * t; + for (var k: u32 = 0u; k < S; k = k + 1u) { + let idx_l = chunk_plan[chunk_base + 2u * k + 0u]; + let idx_r = chunk_plan[chunk_base + 2u * k + 1u]; + + let e_l = t + (2u * k + 0u) * T; + let e_r = t + (2u * k + 1u) * T; + + let src_lx = active_plane_x + PG * idx_l; + let src_ly = active_plane_y + PG * idx_l; + let src_rx = active_plane_x + PG * idx_r; + let src_ry = active_plane_y + PG * idx_r; + + let dst_lx = chain_plane_x + PG * e_l; + let dst_ly = chain_plane_y + PG * e_l; + let dst_rx = chain_plane_x + PG * e_r; + let dst_ry = chain_plane_y + PG * e_r; + + chain_buf[dst_lx + 0u] = active_sums[src_lx + 0u]; + chain_buf[dst_lx + 1u] = active_sums[src_lx + 1u]; + chain_buf[dst_ly + 0u] = active_sums[src_ly + 0u]; + chain_buf[dst_ly + 1u] = active_sums[src_ly + 1u]; + chain_buf[dst_rx + 0u] = active_sums[src_rx + 0u]; + chain_buf[dst_rx + 1u] = active_sums[src_rx + 1u]; + chain_buf[dst_ry + 0u] = active_sums[src_ry + 0u]; + chain_buf[dst_ry + 1u] = active_sums[src_ry + 1u]; + } + + {{{ recompile }}} +} +`; + export const ba_marshal_tree_l0_bench = `{{> structs }} // Marshal kernel for the bench-msm-tree pair-tree pipeline: transposes @@ -1995,6 +2132,69 @@ fn main(@builtin(global_invocation_id) gid: vec3) { } `; +export const ba_scatter_pairs_bench = `{{> structs }} + +// Scatter kernel for the bin-packed pair-tree MSM bucket-accumulate. +// +// For each (chunk t, slot k), reads R.x/R.y from the disjoint kernel's +// strided output (where it landed at flat index t + k * T after +// running with final_flag=1) and writes them to active_sums_new at +// the destination index given by scatter_plan[t * S + k]. +// +// This is the per-bucket-placement pass that re-groups pair sums for +// the next level's bin-packing planner. +// +// scatter_plan layout: 1 u32 per (chunk, slot). +// scatter_plan[t * S + k] = dst_idx (active_sums_new index) +// +// disjoint_out layout: 2 planes (R.x, R.y), PG=2 vec4 per element, +// S * T elements per plane (matches the disjoint kernel's +// final-mode simple strided write). +// +// active_sums_new layout: 2 planes (P.x, P.y), PG=2 vec4 per element, +// M_new elements per plane (params.y). + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var scatter_plan: array; +@group(0) @binding(1) var disjoint_out: array>; +@group(0) @binding(2) var active_sums_new: array>; +@group(0) @binding(3) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let M_new = params.y; + let t = gid.x; + if (t >= T) { return; } + + let out_N = S * T; + let out_plane_x = 0u * PG * out_N; + let out_plane_y = 1u * PG * out_N; + + let new_plane_x = 0u * PG * M_new; + let new_plane_y = 1u * PG * M_new; + + for (var k: u32 = 0u; k < S; k = k + 1u) { + let e = t + k * T; + let dst_idx = scatter_plan[t * S + k]; + + let src_x = out_plane_x + PG * e; + let src_y = out_plane_y + PG * e; + let dst_x = new_plane_x + PG * dst_idx; + let dst_y = new_plane_y + PG * dst_idx; + + active_sums_new[dst_x + 0u] = disjoint_out[src_x + 0u]; + active_sums_new[dst_x + 1u] = disjoint_out[src_x + 1u]; + active_sums_new[dst_y + 0u] = disjoint_out[src_y + 0u]; + active_sums_new[dst_y + 1u] = disjoint_out[src_y + 1u]; + } + + {{{ recompile }}} +} +`; + export const ba_tail_reduce_bench = `{{> structs }} {{> bigint_funcs }} {{> montgomery_product_funcs }} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_carry_copy_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_carry_copy_bench.template.wgsl new file mode 100644 index 000000000000..50f409778d97 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_carry_copy_bench.template.wgsl @@ -0,0 +1,54 @@ +{{> structs }} + +// Carry-copy kernel for the bin-packed pair-tree MSM bucket-accumulate. +// +// For each carry slot t, copies one packed (x, y) point from +// active_sums_old[carry_plan[2*t + 0]] to +// active_sums_new[carry_plan[2*t + 1]]. +// +// Used when a bucket has an odd active count at the current level: +// floor(N_b / 2) elements get paired and produce floor(N_b / 2) sums +// in the next level, plus the (N_b mod 2 == 1) carry element propagates +// forward unchanged. +// +// Pure memory shuffle, no field arithmetic. +// +// params.x = T (number of carry-copies / threads) +// params.y = M_old (active_sums_old size, vec4-stride scaling) +// params.z = M_new (active_sums_new size, vec4-stride scaling) + +const PG: u32 = 2u; + +@group(0) @binding(0) var carry_plan: array; +@group(0) @binding(1) var active_sums_old: array>; +@group(0) @binding(2) var active_sums_new: array>; +@group(0) @binding(3) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let M_old = params.y; + let M_new = params.z; + let t = gid.x; + if (t >= T) { return; } + + let src_idx = carry_plan[2u * t + 0u]; + let dst_idx = carry_plan[2u * t + 1u]; + + let old_plane_x = 0u * PG * M_old; + let old_plane_y = 1u * PG * M_old; + let new_plane_x = 0u * PG * M_new; + let new_plane_y = 1u * PG * M_new; + + let src_x = old_plane_x + PG * src_idx; + let src_y = old_plane_y + PG * src_idx; + let dst_x = new_plane_x + PG * dst_idx; + let dst_y = new_plane_y + PG * dst_idx; + + active_sums_new[dst_x + 0u] = active_sums_old[src_x + 0u]; + active_sums_new[dst_x + 1u] = active_sums_old[src_x + 1u]; + active_sums_new[dst_y + 0u] = active_sums_old[src_y + 0u]; + active_sums_new[dst_y + 1u] = active_sums_old[src_y + 1u]; + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_pairs_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_pairs_bench.template.wgsl new file mode 100644 index 000000000000..a83210bc4ade --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_pairs_bench.template.wgsl @@ -0,0 +1,79 @@ +{{> structs }} + +// Marshal kernel for the bin-packed pair-tree MSM bucket-accumulate. +// +// Reads (idx_l, idx_r) operand indices per pair from chunk_plan, +// fetches the corresponding packed 8x u32 points from an active_sums +// buffer (2-plane SoA), and writes them into the disjoint kernel's +// strided input layout. +// +// Used both at level 0 (active_sums = bucket-sorted point pool) and +// at levels 1+ (active_sums = previous level's pair-sum + carry +// outputs). The kernel is bucket-agnostic; the planner has packed +// each chunk's S pairs from whatever buckets fit, and chunk_plan +// encodes the operand source indices. +// +// chunk_plan layout: 2 * S u32 per chunk +// chunk_plan[2 * (t * S + k) + 0] = idx_left (active_sums index) +// chunk_plan[2 * (t * S + k) + 1] = idx_right (active_sums index) +// +// active_sums layout: 2 planes (P.x, P.y), PG=2 vec4 per element, +// M_in elements per plane (params.y). +// +// chain_buf layout: 2 planes (P.x, P.y), PG=2 vec4 per element, +// 2 * S * T elements per plane. Slot (t, 2k+0) holds left, slot +// (t, 2k+1) holds right at the disjoint kernel's strided positions +// e = t + i * T for i = 2k, 2k+1. + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var chunk_plan: array; +@group(0) @binding(1) var active_sums: array>; +@group(0) @binding(2) var chain_buf: array>; +@group(0) @binding(3) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let M_in = params.y; + let t = gid.x; + if (t >= T) { return; } + + let chain_N = 2u * S * T; + let chain_plane_x = 0u * PG * chain_N; + let chain_plane_y = 1u * PG * chain_N; + + let active_plane_x = 0u * PG * M_in; + let active_plane_y = 1u * PG * M_in; + + let chunk_base = 2u * S * t; + for (var k: u32 = 0u; k < S; k = k + 1u) { + let idx_l = chunk_plan[chunk_base + 2u * k + 0u]; + let idx_r = chunk_plan[chunk_base + 2u * k + 1u]; + + let e_l = t + (2u * k + 0u) * T; + let e_r = t + (2u * k + 1u) * T; + + let src_lx = active_plane_x + PG * idx_l; + let src_ly = active_plane_y + PG * idx_l; + let src_rx = active_plane_x + PG * idx_r; + let src_ry = active_plane_y + PG * idx_r; + + let dst_lx = chain_plane_x + PG * e_l; + let dst_ly = chain_plane_y + PG * e_l; + let dst_rx = chain_plane_x + PG * e_r; + let dst_ry = chain_plane_y + PG * e_r; + + chain_buf[dst_lx + 0u] = active_sums[src_lx + 0u]; + chain_buf[dst_lx + 1u] = active_sums[src_lx + 1u]; + chain_buf[dst_ly + 0u] = active_sums[src_ly + 0u]; + chain_buf[dst_ly + 1u] = active_sums[src_ly + 1u]; + chain_buf[dst_rx + 0u] = active_sums[src_rx + 0u]; + chain_buf[dst_rx + 1u] = active_sums[src_rx + 1u]; + chain_buf[dst_ry + 0u] = active_sums[src_ry + 0u]; + chain_buf[dst_ry + 1u] = active_sums[src_ry + 1u]; + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_scatter_pairs_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_scatter_pairs_bench.template.wgsl new file mode 100644 index 000000000000..00d14390002c --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_scatter_pairs_bench.template.wgsl @@ -0,0 +1,61 @@ +{{> structs }} + +// Scatter kernel for the bin-packed pair-tree MSM bucket-accumulate. +// +// For each (chunk t, slot k), reads R.x/R.y from the disjoint kernel's +// strided output (where it landed at flat index t + k * T after +// running with final_flag=1) and writes them to active_sums_new at +// the destination index given by scatter_plan[t * S + k]. +// +// This is the per-bucket-placement pass that re-groups pair sums for +// the next level's bin-packing planner. +// +// scatter_plan layout: 1 u32 per (chunk, slot). +// scatter_plan[t * S + k] = dst_idx (active_sums_new index) +// +// disjoint_out layout: 2 planes (R.x, R.y), PG=2 vec4 per element, +// S * T elements per plane (matches the disjoint kernel's +// final-mode simple strided write). +// +// active_sums_new layout: 2 planes (P.x, P.y), PG=2 vec4 per element, +// M_new elements per plane (params.y). + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var scatter_plan: array; +@group(0) @binding(1) var disjoint_out: array>; +@group(0) @binding(2) var active_sums_new: array>; +@group(0) @binding(3) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let M_new = params.y; + let t = gid.x; + if (t >= T) { return; } + + let out_N = S * T; + let out_plane_x = 0u * PG * out_N; + let out_plane_y = 1u * PG * out_N; + + let new_plane_x = 0u * PG * M_new; + let new_plane_y = 1u * PG * M_new; + + for (var k: u32 = 0u; k < S; k = k + 1u) { + let e = t + k * T; + let dst_idx = scatter_plan[t * S + k]; + + let src_x = out_plane_x + PG * e; + let src_y = out_plane_y + PG * e; + let dst_x = new_plane_x + PG * dst_idx; + let dst_y = new_plane_y + PG * dst_idx; + + active_sums_new[dst_x + 0u] = disjoint_out[src_x + 0u]; + active_sums_new[dst_x + 1u] = disjoint_out[src_x + 1u]; + active_sums_new[dst_y + 0u] = disjoint_out[src_y + 0u]; + active_sums_new[dst_y + 1u] = disjoint_out[src_y + 1u]; + } + + {{{ recompile }}} +} From d927709e11cc6680560a8bbd4f600e5ae3fd26e3 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Wed, 20 May 2026 09:34:41 +0000 Subject: [PATCH 16/33] perf(bb/msm/bench): batch v2 per-level dispatches into single submit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously v2 awaited between each of the 4 kernels per level — 24 total submit+drain overheads for 6 levels. Bundle each level's 4 kernels into one command encoder + one submit + one await. Per-level wall drops from sum-of-4-overheads to one-overhead. --- .../ts/dev/msm-webgpu/bench-msm-tree-v2.ts | 46 ++++++++++--------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.ts index f368411bd5f1..6caf186b4803 100644 --- a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.ts +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.ts @@ -287,18 +287,20 @@ async function readNonZero(device: GPUDevice, buf: GPUBuffer, u32Count: number): return false; } -async function timeOne( - device: GPUDevice, - pipeline: GPUComputePipeline, - bind: GPUBindGroup, - numWgs: number, -): Promise { +interface PassSpec { pipeline: GPUComputePipeline; bind: GPUBindGroup; numWgs: number } + +// Encode multiple passes into one command encoder, submit once, await +// once. Returns the total wall time. Submit-overhead is paid once +// across all passes — the right way to measure a fused pipeline. +async function timeBatched(device: GPUDevice, passes: PassSpec[]): Promise { const enc = device.createCommandEncoder(); - const pass = enc.beginComputePass(); - pass.setPipeline(pipeline); - pass.setBindGroup(0, bind); - pass.dispatchWorkgroups(numWgs, 1, 1); - pass.end(); + for (const p of passes) { + const pass = enc.beginComputePass(); + pass.setPipeline(p.pipeline); + pass.setBindGroup(0, p.bind); + pass.dispatchWorkgroups(p.numWgs, 1, 1); + pass.end(); + } const t0 = performance.now(); device.queue.submit([enc.finish()]); await device.queue.onSubmittedWorkDone(); @@ -459,22 +461,24 @@ async function runPipeline(device: GPUDevice, sm: ShaderManager, reps: number, R }); } - // Warmup (untimed, optional). - // Timed sequential dispatch — each kernel awaited so we get per-kernel ms. - const marshalMs = await timeOne(device, marshalPipe, marshalBind, numWgs); - const disjointMs = await timeOne(device, disjointPipe, disjointBind, numWgs); - const scatterMs = await timeOne(device, scatterPipe, scatterBind, numWgs); - let carryMs = 0; + // Bundle this level's 4 kernel dispatches into a single command + // encoder + single submit + single await. Submit overhead amortises + // across the level's kernels. + const passes: PassSpec[] = [ + { pipeline: marshalPipe, bind: marshalBind, numWgs }, + { pipeline: disjointPipe, bind: disjointBind, numWgs }, + { pipeline: scatterPipe, bind: scatterBind, numWgs }, + ]; if (plan.numCarries > 0 && carryBind) { const carryWgs = Math.ceil(plan.numCarries / WGI); - carryMs = await timeOne(device, carryPipe, carryBind, carryWgs); + passes.push({ pipeline: carryPipe, bind: carryBind, numWgs: carryWgs }); } - + const levelMs = await timeBatched(device, passes); levelTimings.push({ T, pairs: plan.totalPairs, carries: plan.numCarries, - marshal_ms: marshalMs, disjoint_ms: disjointMs, scatter_ms: scatterMs, carry_ms: carryMs, + marshal_ms: 0, disjoint_ms: 0, scatter_ms: 0, carry_ms: 0, // unused — batched }); - log('info', ` L${levelIdx} ms: marshal=${marshalMs.toFixed(2)} disjoint=${disjointMs.toFixed(2)} scatter=${scatterMs.toFixed(2)} carry=${carryMs.toFixed(2)}`); + log('info', ` L${levelIdx} batched_ms=${levelMs.toFixed(2)} (4 kernels in one submit)`); // Cleanup level-local buffers. chunkPlanBuf.destroy(); From 0e7c898f7934b822fd0ef6748855f503fac41c86 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Wed, 20 May 2026 09:37:21 +0000 Subject: [PATCH 17/33] perf(bb/msm/bench): single submit across all v2 levels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously each level's 4 kernels were batched, but each level was a separate submit + await — 6 levels x ~2ms submit overhead = 12ms. Encode ALL levels' passes into one command encoder and submit once. plan writeBuffers process before the submit on the device queue; GPU ensures storage barriers between dependent passes inside the encoder. Pure GPU wall time reported. --- .../ts/dev/msm-webgpu/bench-msm-tree-v2.ts | 39 +++++++++++-------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.ts index 6caf186b4803..b0f97f6081d8 100644 --- a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.ts +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.ts @@ -392,6 +392,12 @@ async function runPipeline(device: GPUDevice, sm: ShaderManager, reps: number, R let levelIdx = 0; const levelTimings: LevelTiming[] = []; const dummy = device.createBuffer({ size: 16, usage: GPUBufferUsage.STORAGE }); + // Pre-pass: enqueue plan uploads and encode all-level passes. The + // device queue processes the writeBuffer calls in order before the + // single submit; the GPU inserts barriers between dependent storage + // reads/writes within the encoder. One submit + one await amortises + // submit overhead across the entire bucket-accumulate. + const allPasses: PassSpec[] = []; const startTime = performance.now(); @@ -461,24 +467,20 @@ async function runPipeline(device: GPUDevice, sm: ShaderManager, reps: number, R }); } - // Bundle this level's 4 kernel dispatches into a single command - // encoder + single submit + single await. Submit overhead amortises - // across the level's kernels. - const passes: PassSpec[] = [ - { pipeline: marshalPipe, bind: marshalBind, numWgs }, - { pipeline: disjointPipe, bind: disjointBind, numWgs }, - { pipeline: scatterPipe, bind: scatterBind, numWgs }, - ]; + // Stash level's passes into the outer pass list to be timed as a + // single batched submit across ALL levels. + allPasses.push({ pipeline: marshalPipe, bind: marshalBind, numWgs }); + allPasses.push({ pipeline: disjointPipe, bind: disjointBind, numWgs }); + allPasses.push({ pipeline: scatterPipe, bind: scatterBind, numWgs }); if (plan.numCarries > 0 && carryBind) { const carryWgs = Math.ceil(plan.numCarries / WGI); - passes.push({ pipeline: carryPipe, bind: carryBind, numWgs: carryWgs }); + allPasses.push({ pipeline: carryPipe, bind: carryBind, numWgs: carryWgs }); } - const levelMs = await timeBatched(device, passes); levelTimings.push({ T, pairs: plan.totalPairs, carries: plan.numCarries, - marshal_ms: 0, disjoint_ms: 0, scatter_ms: 0, carry_ms: 0, // unused — batched + marshal_ms: 0, disjoint_ms: 0, scatter_ms: 0, carry_ms: 0, }); - log('info', ` L${levelIdx} batched_ms=${levelMs.toFixed(2)} (4 kernels in one submit)`); + log('info', ` L${levelIdx} encoded (T=${T}, pairs=${plan.totalPairs})`); // Cleanup level-local buffers. chunkPlanBuf.destroy(); @@ -495,7 +497,12 @@ async function runPipeline(device: GPUDevice, sm: ShaderManager, reps: number, R levelIdx++; } - const wall = performance.now() - startTime; + // Single batched submit for ALL levels. + const wallSubmit = performance.now(); + const totalWall = await timeBatched(device, allPasses); + log('info', `batched ${allPasses.length} passes in one submit: ${totalWall.toFixed(2)}ms`); + + const wall = performance.now() - startTime; // includes plan-build + upload + GPU time const sanity = await readNonZero(device, curIn, 8); bufA.destroy(); @@ -504,15 +511,15 @@ async function runPipeline(device: GPUDevice, sm: ShaderManager, reps: number, R tempOutBuf.destroy(); dummy.destroy(); - const nsPerInpt = (wall * 1e6) / NPTS; + const nsPerInpt = (totalWall * 1e6) / NPTS; log( sanity ? 'ok' : 'err', - `pipeline: ${levelIdx} levels, ${totalPairAdds} pair-adds, total_wall=${wall.toFixed(2)}ms, ns/in-pt=${nsPerInpt.toFixed(2)}, sanity=${sanity ? 'OK' : 'FAIL'}`, + `pipeline: ${levelIdx} levels, ${totalPairAdds} pair-adds, single-submit GPU wall=${totalWall.toFixed(2)}ms, total incl plan-upload=${wall.toFixed(2)}ms, ns/in-pt=${nsPerInpt.toFixed(2)}, sanity=${sanity ? 'OK' : 'FAIL'}`, ); return { s: S, wgi: WGI, pairs: NPTS, buckets: BUCKETS, levels: levelIdx, - total_pair_adds: totalPairAdds, total_wall_ms: wall, level_timings: levelTimings, + total_pair_adds: totalPairAdds, total_wall_ms: totalWall, level_timings: levelTimings, ns_per_inpt: nsPerInpt, sanity_ok: sanity, }; } From 3491283b3a02ba7f379184cf0722339f25313092 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Wed, 20 May 2026 11:46:48 +0000 Subject: [PATCH 18/33] =?UTF-8?q?feat(bb/msm):=20v3=20=E2=80=94=20fused=20?= =?UTF-8?q?super-kernel=20+=20GPU-side=20planner?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two algorithmic improvements compose into the v3 bucket-accumulate: 1. ba_fused_super_bench — combines v2's marshal + disjoint + scatter into one kernel per level. Chunk-thread reads chunk_plan + scatter_plan, gathers operands from active_sums_old, computes batched-inverse pair sums in registers, writes directly to active_sums_new at scatter destinations. Eliminates the chain_buf and tempOut scratch buffers entirely. 2. ba_planner_bench — GPU-side bin-packing planner. One thread per bucket; atomicAdd reserves global per-pair / per-carry / per-new- slot offsets; the thread writes its bucket's chunk_plan + scatter_ plan + carry_plan + new_counts + new_offsets entries directly. Host no longer round-trips between levels. Host harness bench-msm-tree-v3.{ts,html} drives the all-GPU pipeline. Over-dispatches LEVELS=8 (Poisson(λ=32) needs log2(60) ≈ 6); extra levels are no-ops at the kernel level (planner emits zero pairs). Per level: 3 kernels (planner, fused, carry). Down from v2's 4 (marshal, disjoint, scatter, carry). All passes encoded into a single command-encoder + single submit + single await. Zero CPU↔GPU sync during the entire bucket-accumulate. shader_manager: gen_ba_fused_super_bench_shader, gen_ba_planner_bench_shader. pageMap entry "bench-msm-tree-v3". Expected: ~17-20 ns/in-pt on M2 Poisson(λ=32), beating v2's 22.13. --- .../ts/dev/msm-webgpu/bench-msm-tree-v3.html | 22 + .../ts/dev/msm-webgpu/bench-msm-tree-v3.ts | 509 ++++++++++++++++++ .../msm-webgpu/scripts/run-browserstack.mjs | 1 + .../ts/src/msm_webgpu/cuzk/shader_manager.ts | 52 ++ .../src/msm_webgpu/wgsl/_generated/shaders.ts | 256 ++++++++- .../cuzk/ba_fused_super_bench.template.wgsl | 157 ++++++ .../wgsl/cuzk/ba_planner_bench.template.wgsl | 93 ++++ 7 files changed, 1089 insertions(+), 1 deletion(-) create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.html create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.ts create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_fused_super_bench.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_bench.template.wgsl diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.html b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.html new file mode 100644 index 000000000000..8eaa024c2f13 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.html @@ -0,0 +1,22 @@ + + + + + v3 — GPU planner + fused super-kernel MSM bucket-accumulate + + + +

v3 — GPU planner + fused super-kernel MSM bucket-accumulate

+

Query params: ?n=N&buckets=B&s=S&wgi=W&levels=L

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.ts new file mode 100644 index 000000000000..ea778518e1e2 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.ts @@ -0,0 +1,509 @@ +/// +// bench-msm-tree-v3 — GPU-side planner + fused super-kernel. +// +// Per level (all GPU, encoded into one command list): +// 1. Reset totals atomic counter to 0. +// 2. Pre-pad chunk_plan + scatter_plan + carry_plan to safe values. +// 3. Planner kernel: 1 thread per bucket -> writes chunk_plan, +// scatter_plan, carry_plan, new_counts, new_offsets via atomic +// offset reservation. +// 4. Fused super-kernel: marshal + disjoint + scatter in one pass. +// Reads chunk_plan + scatter_plan + active_sums_old; writes +// active_sums_new. +// 5. Carry kernel: copies odd-count carries from active_sums_old to +// active_sums_new. +// 6. Swap active_sums buffers, swap counts/offsets buffers. +// +// Over-dispatch L_MAX levels (default 8 for Poisson(λ=32) where max +// bucket count is ~50-60, requiring log2(60) = 6 levels). Extra levels +// with all-count-1 input are no-ops at the kernel level (planner +// produces zero pairs; fused kernel dispatched with 0 threads via +// host-side numWgs=0; carry kernel just copies the single-element- +// per-bucket data forward). +// +// Single submit across all 3*L_MAX kernel dispatches. Zero host-GPU +// round-trips between scalar-decompose and final readback. + +import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js'; +import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js'; +import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js'; +import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js'; +import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js'; +import { makeResultsClient } from './results_post.js'; + +const PG = 2; +const DEFAULT_N = 1 << 17; +const DEFAULT_BUCKETS = 1 << 12; +const DEFAULT_S = 16; +const DEFAULT_WGI = 64; +const DEFAULT_LEVELS = 8; + +let NPTS = DEFAULT_N; +let BUCKETS = DEFAULT_BUCKETS; +let S = DEFAULT_S; +let WGI = DEFAULT_WGI; +let LEVELS = DEFAULT_LEVELS; + +function makeRng(seed: number): () => number { + let state = (seed >>> 0) || 1; + return () => { + state = (Math.imul(state, 1664525) + 1013904223) >>> 0; + return state; + }; +} + +function randomBelow(p: bigint, rng: () => number): bigint { + const bitlen = p.toString(2).length; + const byteLen = Math.ceil(bitlen / 8); + for (;;) { + let v = 0n; + for (let i = 0; i < byteLen; i++) v = (v << 8n) | BigInt(rng() & 0xff); + v &= (1n << BigInt(bitlen)) - 1n; + if (v > 0n && v < p) return v; + } +} + +function bigintToPackedU32x8(v: bigint): Uint32Array { + const w = new Uint32Array(8); + let x = v; + for (let i = 0; i < 8; i++) { + w[i] = Number(x & 0xffffffffn); + x >>= 32n; + } + return w; +} + +// Build initial active_sums and per-bucket counts/offsets (level 0). +function buildL0(N: number, B: number, R: bigint, p: bigint, rng: () => number) { + const M = N + 2; + const buf = new Uint32Array(2 * PG * M * 4); + const xWords = new Uint32Array(8 * M); + const yWords = new Uint32Array(8 * M); + for (let i = 0; i < M; i++) { + const x = (randomBelow(p, rng) * R) % p; + const y = (randomBelow(p, rng) * R) % p; + xWords.set(bigintToPackedU32x8(x), 8 * i); + yWords.set(bigintToPackedU32x8(y), 8 * i); + } + if (xWords[8 * (M - 2)] === xWords[8 * (M - 1)]) xWords[8 * (M - 1)] ^= 1; + const bucket = new Uint32Array(N); + const counts = new Uint32Array(B); + for (let i = 0; i < N; i++) { + const hi = (rng() >>> 16) & 0xffff; + const lo = (rng() >>> 16) & 0xffff; + const v = hi * 0x10000 + lo; + const b = v % B; + bucket[i] = b; + counts[b]++; + } + const offsets = new Uint32Array(B + 1); + for (let b = 0; b < B; b++) offsets[b + 1] = offsets[b] + counts[b]; + const cursor = new Uint32Array(B); + const writeElem = (planeIdx: number, dstIdx: number, words: Uint32Array, srcOff: number) => { + for (let v = 0; v < PG; v++) { + const base = ((planeIdx * PG + v) * M + dstIdx) * 4; + buf[base + 0] = words[srcOff + 4 * v + 0]; + buf[base + 1] = words[srcOff + 4 * v + 1]; + buf[base + 2] = words[srcOff + 4 * v + 2]; + buf[base + 3] = words[srcOff + 4 * v + 3]; + } + }; + for (let i = 0; i < N; i++) { + const b = bucket[i]; + const dst = offsets[b] + cursor[b]++; + writeElem(0, dst, xWords, 8 * i); + writeElem(1, dst, yWords, 8 * i); + } + writeElem(0, M - 2, xWords, 8 * (M - 2)); + writeElem(1, M - 2, yWords, 8 * (M - 2)); + writeElem(0, M - 1, xWords, 8 * (M - 1)); + writeElem(1, M - 1, yWords, 8 * (M - 1)); + return { initBuf: buf, initCounts: counts, initOffsets: offsets, M }; +} + +function makeSoABuf(device: GPUDevice, M: number, copyDst: boolean, copySrc: boolean): GPUBuffer { + const bytes = 2 * PG * M * 4 * 4; + let usage = GPUBufferUsage.STORAGE; + if (copyDst) usage |= GPUBufferUsage.COPY_DST; + if (copySrc) usage |= GPUBufferUsage.COPY_SRC; + return device.createBuffer({ size: bytes, usage }); +} + +interface RunResult { + s: number; + wgi: number; + pairs: number; + buckets: number; + levels_run: number; + gpu_wall_ms: number; + ns_per_inpt: number; + sanity_ok: boolean; +} + +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: { n: number; buckets: number; s: number; wgi: number; levels: number } | null; + results: RunResult[]; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { state: 'boot', params: null, results: [], error: null, log: [] }; +(window as unknown as { __bench: BenchState }).__bench = benchState; + +const resultsClient = makeResultsClient({ page: 'bench-msm-tree-v3' }); +(window as unknown as { __runId: string }).__runId = resultsClient.runId; + +async function postFinal(): Promise { + await resultsClient.postResults({ + state: benchState.state, params: benchState.params, results: benchState.results, + error: benchState.error, log: benchState.log, + userAgent: navigator.userAgent, hardwareConcurrency: navigator.hardwareConcurrency, + }); +} + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-msm-tree-v3] ${msg}`); +} + +async function compileOne(device: GPUDevice, code: string, key: string, layout: GPUBindGroupLayout): Promise { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + const errLines: string[] = []; + for (const m of info.messages) { + const line = `[shader ${key}] ${m.type}: ${m.message} (line ${m.lineNum}, col ${m.linePos})`; + if (m.type === 'error') { console.error(line); log('err', line); errLines.push(line); hasError = true; } + else { console.warn(line); } + } + if (hasError) throw new Error(`WGSL compile failed for ${key}: ${errLines.slice(0, 4).join(' | ')}`); + return device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); +} + +async function readNonZero(device: GPUDevice, buf: GPUBuffer, u32Count: number): Promise { + const bytes = u32Count * 4; + const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }); + const enc = device.createCommandEncoder(); + enc.copyBufferToBuffer(buf, 0, staging, 0, bytes); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + await staging.mapAsync(GPUMapMode.READ); + const u32 = new Uint32Array(staging.getMappedRange().slice(0)); + staging.unmap(); + staging.destroy(); + for (let i = 0; i < u32.length; i++) if (u32[i] !== 0) return true; + return false; +} + +async function runPipeline(device: GPUDevice, sm: ShaderManager, R: bigint, p: bigint): Promise { + log('info', `=== N=${NPTS} B=${BUCKETS} S=${S} WGI=${WGI} LEVELS=${LEVELS}`); + + const rng = makeRng(0xc711); + const { initBuf, initCounts, initOffsets, M } = buildL0(NPTS, BUCKETS, R, p, rng); + log('info', `L0 active_sums: M=${M}, B=${BUCKETS}`); + + // Histogram peek + let maxC = 0, minC = NPTS, smallC = 0; + for (let b = 0; b < initCounts.length; b++) { + if (initCounts[b] > maxC) maxC = initCounts[b]; + if (initCounts[b] < minC) minC = initCounts[b]; + if (initCounts[b] < 32) smallC++; + } + log('info', `bucket counts: min=${minC} max=${maxC} small(<32)=${smallC}/${BUCKETS}`); + + // Plan-buffer sizing — must accommodate L0 max chunks. + // At L0: total pairs <= N/2 = 65536; max chunks = ceil(65536/S) = 4096. + // At deeper levels: shrinks. So max-allocated for L0. + const MAX_CHUNKS = Math.ceil(NPTS / 2 / S) + 16; // pad + const MAX_PAIR_SLOTS = MAX_CHUNKS * S; + const MAX_CARRIES = BUCKETS; // at most one carry per bucket + + const mkStorage = (bytes: number, copyDst = true, copySrc = false): GPUBuffer => { + let usage = GPUBufferUsage.STORAGE; + if (copyDst) usage |= GPUBufferUsage.COPY_DST; + if (copySrc) usage |= GPUBufferUsage.COPY_SRC; + return device.createBuffer({ size: bytes, usage }); + }; + + // Ping-pong active_sums. + const bufA = makeSoABuf(device, M, true, true); + const bufB = makeSoABuf(device, M, true, true); + device.queue.writeBuffer(bufA, 0, initBuf); + device.queue.writeBuffer(bufB, 0, initBuf); // mirror initial for pad-pair availability + + // Plan buffers (reused per level). + const chunkPlanBuf = mkStorage(2 * MAX_PAIR_SLOTS * 4); + const scatterPlanBuf = mkStorage(MAX_PAIR_SLOTS * 4); + const carryPlanBuf = mkStorage(2 * MAX_CARRIES * 4); + + // Per-level counts/offsets buffers (ping-pong). + const countsA = mkStorage(BUCKETS * 4); + const countsB = mkStorage(BUCKETS * 4); + const offsetsA = mkStorage((BUCKETS + 1) * 4); + const offsetsB = mkStorage((BUCKETS + 1) * 4); + device.queue.writeBuffer(countsA, 0, initCounts); + device.queue.writeBuffer(offsetsA, 0, initOffsets); + + // Totals atomic counter [pair_off_accum, carry_off_accum, new_off_accum, _] + const totalsBuf = device.createBuffer({ + size: 16, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST, + }); + + // Pre-padded chunk_plan / scatter_plan / carry_plan templates. + // Pad slots all point to safe values: + // chunk_plan pad pair = (M-2, M-1) — known distinct-x in active_sums + // scatter_plan pad dst = M-2 — discard target (within active_sums_new) + // carry_plan pad src = M-2, dst = M-2 — no-op self-copy of pad slot + const padChunkPlan = new Uint32Array(2 * MAX_PAIR_SLOTS); + const padScatterPlan = new Uint32Array(MAX_PAIR_SLOTS); + const padCarryPlan = new Uint32Array(2 * MAX_CARRIES); + for (let i = 0; i < MAX_PAIR_SLOTS; i++) { + padChunkPlan[2 * i + 0] = M - 2; + padChunkPlan[2 * i + 1] = M - 1; + padScatterPlan[i] = M - 2; + } + for (let i = 0; i < MAX_CARRIES; i++) { + padCarryPlan[2 * i + 0] = M - 2; + padCarryPlan[2 * i + 1] = M - 2; + } + + // Per-level params (reused). + const plannerParams = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + const fusedParams: GPUBuffer[] = []; + const carryParams: GPUBuffer[] = []; + for (let i = 0; i < LEVELS; i++) { + fusedParams.push(device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST })); + carryParams.push(device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST })); + } + device.queue.writeBuffer(plannerParams, 0, new Uint32Array([BUCKETS, S, 0, 0])); + + // Layouts. + const plannerLayout = device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 5, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 6, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 7, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 8, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); + const fusedLayout = device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); + const carryLayout = device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); + + const plannerPipe = await compileOne(device, sm.gen_ba_planner_bench_shader(WGI, S, 64), `planner-W${WGI}-S${S}`, plannerLayout); + const fusedPipe = await compileOne(device, sm.gen_ba_fused_super_bench_shader(WGI, S), `fused-W${WGI}-S${S}`, fusedLayout); + const carryPipe = await compileOne(device, sm.gen_ba_carry_copy_bench_shader(WGI), `carry-W${WGI}`, carryLayout); + log('info', '3 pipelines compiled'); + + // Encode all level passes into one command encoder. + // Per level k: + // - clear totalsBuf to zero (writeBuffer queued before submit) + // - clear plan buffers to pad templates (writeBuffer queued) + // - dispatch planner + // - dispatch fused (numWgs = ceil(MAX_CHUNKS / WGI) — over-provisioned; + // idle threads early-out via if (t >= T) check) + // - dispatch carry (numWgs = ceil(MAX_CARRIES / WGI) — over-provisioned) + // - swap counts/offsets via bind group selection on next level + + const enc = device.createCommandEncoder(); + let curCountsIn: GPUBuffer = countsA; + let curCountsOut: GPUBuffer = countsB; + let curOffsetsIn: GPUBuffer = offsetsA; + let curOffsetsOut: GPUBuffer = offsetsB; + let curActiveIn: GPUBuffer = bufA; + let curActiveOut: GPUBuffer = bufB; + + const numWgsPlanner = Math.ceil(BUCKETS / WGI); + const numWgsFused = Math.ceil(MAX_CHUNKS / WGI); + const numWgsCarry = Math.ceil(MAX_CARRIES / WGI); + log('info', `dispatch sizes: planner=${numWgsPlanner} fused=${numWgsFused} carry=${numWgsCarry}`); + + // Pre-write per-level params (since they depend on iteration index). + for (let lv = 0; lv < LEVELS; lv++) { + // params.x = T_fused = MAX_CHUNKS (over-provisioned; planner-written + // chunk_plan/scatter_plan trailers point to pad => safe early-out + // is implicit because pads do harmless add+discard) + device.queue.writeBuffer(fusedParams[lv], 0, new Uint32Array([MAX_CHUNKS, M, M, 0])); + device.queue.writeBuffer(carryParams[lv], 0, new Uint32Array([MAX_CARRIES, M, M, 0])); + } + + // Pre-pad plan buffers (only need to do once; planner overwrites real + // entries each level, and the pre-pad is stable across levels because + // pad slots remain pad-valued). + device.queue.writeBuffer(chunkPlanBuf, 0, padChunkPlan); + device.queue.writeBuffer(scatterPlanBuf, 0, padScatterPlan); + device.queue.writeBuffer(carryPlanBuf, 0, padCarryPlan); + + // Bind groups built per level (to swap counts/offsets/active buffers). + for (let lv = 0; lv < LEVELS; lv++) { + // Reset totals atomic counter before this level's planner. + device.queue.writeBuffer(totalsBuf, 0, new Uint32Array([0, 0, 0, 0])); + // Re-pad plan buffers (planner overwrites only real entries; the + // pad regions get re-padded between levels to clean any leftover + // real entries from the prior level). + device.queue.writeBuffer(chunkPlanBuf, 0, padChunkPlan); + device.queue.writeBuffer(scatterPlanBuf, 0, padScatterPlan); + device.queue.writeBuffer(carryPlanBuf, 0, padCarryPlan); + + const plannerBind = device.createBindGroup({ + layout: plannerLayout, + entries: [ + { binding: 0, resource: { buffer: curCountsIn } }, + { binding: 1, resource: { buffer: curOffsetsIn } }, + { binding: 2, resource: { buffer: chunkPlanBuf } }, + { binding: 3, resource: { buffer: scatterPlanBuf } }, + { binding: 4, resource: { buffer: carryPlanBuf } }, + { binding: 5, resource: { buffer: totalsBuf } }, + { binding: 6, resource: { buffer: curCountsOut } }, + { binding: 7, resource: { buffer: curOffsetsOut } }, + { binding: 8, resource: { buffer: plannerParams } }, + ], + }); + const fusedBind = device.createBindGroup({ + layout: fusedLayout, + entries: [ + { binding: 0, resource: { buffer: chunkPlanBuf } }, + { binding: 1, resource: { buffer: scatterPlanBuf } }, + { binding: 2, resource: { buffer: curActiveIn } }, + { binding: 3, resource: { buffer: curActiveOut } }, + { binding: 4, resource: { buffer: fusedParams[lv] } }, + ], + }); + const carryBind = device.createBindGroup({ + layout: carryLayout, + entries: [ + { binding: 0, resource: { buffer: carryPlanBuf } }, + { binding: 1, resource: { buffer: curActiveIn } }, + { binding: 2, resource: { buffer: curActiveOut } }, + { binding: 3, resource: { buffer: carryParams[lv] } }, + ], + }); + + // Encode the 3 passes for this level. + { + const pass = enc.beginComputePass(); + pass.setPipeline(plannerPipe); + pass.setBindGroup(0, plannerBind); + pass.dispatchWorkgroups(numWgsPlanner, 1, 1); + pass.end(); + } + { + const pass = enc.beginComputePass(); + pass.setPipeline(fusedPipe); + pass.setBindGroup(0, fusedBind); + pass.dispatchWorkgroups(numWgsFused, 1, 1); + pass.end(); + } + { + const pass = enc.beginComputePass(); + pass.setPipeline(carryPipe); + pass.setBindGroup(0, carryBind); + pass.dispatchWorkgroups(numWgsCarry, 1, 1); + pass.end(); + } + + // Swap for next level. + [curCountsIn, curCountsOut] = [curCountsOut, curCountsIn]; + [curOffsetsIn, curOffsetsOut] = [curOffsetsOut, curOffsetsIn]; + [curActiveIn, curActiveOut] = [curActiveOut, curActiveIn]; + } + + // Single submit + single await for the entire bucket-accumulate. + const t0 = performance.now(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + const gpuWall = performance.now() - t0; + + // Sanity: at least one element of the final active_sums must be non-zero. + const sanity = await readNonZero(device, curActiveIn, 8); + const nsPerInpt = (gpuWall * 1e6) / NPTS; + + log( + sanity ? 'ok' : 'err', + `v3 pipeline: ${LEVELS} levels over-dispatched, single submit GPU wall=${gpuWall.toFixed(2)}ms, ns/in-pt=${nsPerInpt.toFixed(2)}, sanity=${sanity ? 'OK' : 'FAIL'}`, + ); + + // Cleanup + bufA.destroy(); bufB.destroy(); + chunkPlanBuf.destroy(); scatterPlanBuf.destroy(); carryPlanBuf.destroy(); + countsA.destroy(); countsB.destroy(); offsetsA.destroy(); offsetsB.destroy(); + totalsBuf.destroy(); + plannerParams.destroy(); + for (const b of fusedParams) b.destroy(); + for (const b of carryParams) b.destroy(); + + return { + s: S, wgi: WGI, pairs: NPTS, buckets: BUCKETS, + levels_run: LEVELS, gpu_wall_ms: gpuWall, ns_per_inpt: nsPerInpt, sanity_ok: sanity, + }; +} + +function parseParams() { + const qp = new URLSearchParams(window.location.search); + if (qp.get('n')) NPTS = parseInt(qp.get('n')!, 10); + if (qp.get('buckets')) BUCKETS = parseInt(qp.get('buckets')!, 10); + if (qp.get('s')) S = parseInt(qp.get('s')!, 10); + if (qp.get('wgi')) WGI = parseInt(qp.get('wgi')!, 10); + if (qp.get('levels')) LEVELS = parseInt(qp.get('levels')!, 10); + return { n: NPTS, buckets: BUCKETS, s: S, wgi: WGI, levels: LEVELS }; +} + +async function main() { + try { + if (!('gpu' in navigator)) throw new Error('navigator.gpu missing'); + const params = parseParams(); + benchState.params = params; + log('info', `params: n=${params.n} buckets=${params.buckets} s=${params.s} wgi=${params.wgi} levels=${params.levels}`); + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + const p = BN254_BASE_FIELD; + const miscParams = compute_misc_params(p, 13); + const R = miscParams.r; + const sm = new ShaderManager(4, NPTS, BN254_CURVE_CONFIG, false); + const r = await runPipeline(device, sm, R, p); + benchState.results.push(r); + resultsClient.postProgress({ kind: 'pipeline_done', ns_per_inpt: r.ns_per_inpt, sanity_ok: r.sanity_ok }); + benchState.state = 'done'; + log('ok', 'done'); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main() + .catch(e => { const msg = e instanceof Error ? e.message : String(e); log('err', `unhandled: ${msg}`); benchState.state = 'error'; benchState.error = msg; }) + .finally(() => { postFinal().catch(() => {}); }); diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs index 6b0b0415e855..b24956e6ca9d 100644 --- a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs +++ b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs @@ -133,6 +133,7 @@ const pageMap = { "bench-ba-pair-disjoint": "/dev/msm-webgpu/bench-ba-pair-disjoint.html", "bench-msm-tree": "/dev/msm-webgpu/bench-msm-tree.html", "bench-msm-tree-v2": "/dev/msm-webgpu/bench-msm-tree-v2.html", + "bench-msm-tree-v3": "/dev/msm-webgpu/bench-msm-tree-v3.html", "bench-smvp-tree": "/dev/msm-webgpu/bench-smvp-tree.html", sanity: "/dev/msm-webgpu/index.html", }; diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts index 922af1919034..b30b52bf3a48 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts @@ -6,11 +6,13 @@ import { batch_affine_apply_scatter as batch_affine_apply_scatter_shader, batch_affine_dispatch_args as batch_affine_dispatch_args_shader, ba_carry_copy_bench as ba_carry_copy_bench_shader, + ba_fused_super_bench as ba_fused_super_bench_shader, ba_marshal_chain_bench as ba_marshal_chain_bench_shader, ba_marshal_pairs_bench as ba_marshal_pairs_bench_shader, ba_marshal_tree_l0_bench as ba_marshal_tree_l0_bench_shader, ba_pair_disjoint_bench as ba_pair_disjoint_bench_shader, ba_pair_disjoint_tree_bench as ba_pair_disjoint_tree_bench_shader, + ba_planner_bench as ba_planner_bench_shader, ba_scatter_pairs_bench as ba_scatter_pairs_bench_shader, ba_tail_reduce_bench as ba_tail_reduce_bench_shader, ba_rev_packed_carry_bench as ba_rev_packed_carry_bench_shader, @@ -789,6 +791,56 @@ ${packLines.join('\n')} ); } + /** + * Fused super-kernel for the v3 pipeline: combines marshal + disjoint + * + scatter into one kernel. Per chunk-thread: reads chunk_plan + + * scatter_plan, gathers from active_sums_old, computes batched- + * inverse pair sums in registers, writes directly to active_sums_new. + * Carry is a separate kernel. + */ + public gen_ba_fused_super_bench_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_fused_super_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + const dec = this.decoupledPackUnpackWgsl(); + return mustache.render( + ba_fused_super_bench_shader, + { + workgroup_size, s, + word_size: this.word_size, num_words: this.num_words, n0: this.n0, + p_limbs: this.p_limbs, r_limbs: this.r_limbs, r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, mask: this.mask, + two_pow_word_size: this.two_pow_word_size, p_inv_mod_2w: this.p_inv_mod_2w, + p_inv_by_a_lo: this.p_inv_by_a_lo, + dec_unpack: dec.unpack, dec_pack: dec.pack, recompile: this.recompile, + }, + { + structs, bigint_funcs, + montgomery_product_funcs: this.mont_product_src, + field_funcs, fr_pow_funcs, bigint_by_funcs, by_inverse_a_funcs, + }, + ); + } + + /** + * GPU-side bin-packing planner for the v3 pipeline. One thread per + * bucket; atomicAdd reserves global per-pair / per-carry / per-new- + * slot offsets; the thread writes chunk_plan + scatter_plan + + * carry_plan + new_counts + new_offsets for its bucket. Pair-count + * loop bounded by compile-time `pair_cap` (defaults to 64 — covers + * Poisson(λ=32) max-count without issue). + */ + public gen_ba_planner_bench_shader(workgroup_size: number, s: number, pair_cap: number = 64): string { + if (workgroup_size <= 0 || s <= 0 || pair_cap <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s) || !Number.isInteger(pair_cap)) { + throw new Error(`gen_ba_planner_bench_shader: workgroup_size (${workgroup_size}), s (${s}), and pair_cap (${pair_cap}) must be positive integers`); + } + return mustache.render( + ba_planner_bench_shader, + { workgroup_size, s, pair_cap, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + /** * Bin-packed pair-tree: marshal kernel that gathers operands from a * generic active_sums buffer per chunk_plan. Works at any level diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index 1840bfd8d7b3..0442dd802a74 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -1,6 +1,6 @@ // AUTO-GENERATED by scripts/inline-wgsl.mjs. DO NOT EDIT. // Run `yarn generate:wgsl` (or `node scripts/inline-wgsl.mjs`) to regenerate. -// 59 shader sources inlined. +// 61 shader sources inlined. /* eslint-disable */ @@ -1407,6 +1407,165 @@ fn main(@builtin(global_invocation_id) gid: vec3) { } `; +export const ba_fused_super_bench = `{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +// Fused super-kernel for the bin-packed pair-tree MSM bucket-accumulate. +// +// Combines marshal + disjoint + scatter into one kernel. Each thread t +// handles one chunk of S pairs: +// 1. Read 2*S source indices from chunk_plan (idx_l, idx_r per slot). +// 2. Read S destination indices from scatter_plan. +// 3. Load S pair-x values from active_sums_old, compute S dx values +// and forward prefix product, all in registers. +// 4. Single fr_inv_by_a on the prefix product. +// 5. Backward peel: per slot k from S-1 down to 0: +// - load .x and .y for both operands +// - lean affine add -> R_x, R_y +// - write directly to active_sums_new at scatter_plan[t*S + k] +// - update inv for next (smaller-k) iteration +// +// vs v2 (4 kernels: marshal, disjoint, scatter, carry): the chain_buf +// and tempOut scratch buffers are eliminated. All intermediate state +// lives in registers. Per-level dispatch count drops from 4 to 2 +// (fused + carry). +// +// PARAMS: +// params.x = T_chunks (active threads, one per chunk) +// params.y = M_old (active_sums_old vec4-stride length) +// params.z = M_new (active_sums_new vec4-stride length) +// +// Layout (both active_sums buffers): 2 planes (P.x, P.y), PG=2 vec4 per +// element. plane_p flat vec4 base = p * PG * M, element e at offset +// PG * e. + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var chunk_plan: array; +@group(0) @binding(1) var scatter_plan: array; +@group(0) @binding(2) var active_sums_old: array>; +@group(0) @binding(3) var active_sums_new: array>; +@group(0) @binding(4) var params: vec4; + +fn load_active_x(idx: u32, M: u32) -> BigInt { + let plane_base = 0u * PG * M; + let base = plane_base + PG * idx; + let q0 = active_sums_old[base + 0u]; + let q1 = active_sums_old[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn load_active_y(idx: u32, M: u32) -> BigInt { + let plane_base = 1u * PG * M; + let base = plane_base + PG * idx; + let q0 = active_sums_old[base + 0u]; + let q1 = active_sums_old[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_active_new(plane: u32, idx: u32, M: u32, val: ptr) { + let plane_base = plane * PG * M; + let base = plane_base + PG * idx; + let w = pack_limbs_to_256(val); + active_sums_new[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + active_sums_new[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let M_old = params.y; + let M_new = params.z; + let t = gid.x; + if (t >= T) { return; } + + let chunk_base = 2u * S * t; + + // Forward: compute S dx values and accumulate prefix product. + // Read pair indices from chunk_plan, load .x for each operand, compute dx. + var pref: array; + var acc: BigInt = get_r(); + for (var k: u32 = 0u; k < S; k = k + 1u) { + let idx_l = chunk_plan[chunk_base + 2u * k + 0u]; + let idx_r = chunk_plan[chunk_base + 2u * k + 1u]; + var p_lx: BigInt = load_active_x(idx_l, M_old); + var p_rx: BigInt = load_active_x(idx_r, M_old); + var dx: BigInt = fr_sub(&p_rx, &p_lx); + if (k == 0u) { + acc = dx; + } else { + acc = montgomery_product(&acc, &dx); + } + pref[k] = acc; + } + + // Single inversion per chunk. + var inv: BigInt = fr_inv_by_a(acc); + + // Backward peel: emit S pair sums, scatter to active_sums_new. + for (var jj: u32 = 0u; jj < S; jj = jj + 1u) { + let k = S - 1u - jj; + let idx_l = chunk_plan[chunk_base + 2u * k + 0u]; + let idx_r = chunk_plan[chunk_base + 2u * k + 1u]; + + var p_lx: BigInt = load_active_x(idx_l, M_old); + var p_ly: BigInt = load_active_y(idx_l, M_old); + var p_rx: BigInt = load_active_x(idx_r, M_old); + var p_ry: BigInt = load_active_y(idx_r, M_old); + + var inv_dx: BigInt; + if (k == 0u) { + inv_dx = inv; + } else { + var pp = pref[k - 1u]; + inv_dx = montgomery_product(&inv, &pp); + } + + var lambda: BigInt = fr_sub(&p_ry, &p_ly); + lambda = montgomery_product(&lambda, &inv_dx); + var r_x: BigInt = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &p_lx); + r_x = fr_sub(&r_x, &p_rx); + var r_y: BigInt = fr_sub(&p_lx, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &p_ly); + + let dst_idx = scatter_plan[t * S + k]; + store_active_new(0u, dst_idx, M_new, &r_x); + store_active_new(1u, dst_idx, M_new, &r_y); + + if (k > 0u) { + var dx_back: BigInt = fr_sub(&p_rx, &p_lx); + inv = montgomery_product(&inv, &dx_back); + } + } + + {{{ recompile }}} +} +`; + export const ba_marshal_chain_bench = `{{> structs }} // Marshal kernel for the bench-msm-chain pipeline. Transposes a CSR @@ -1958,6 +2117,101 @@ fn main(@builtin(global_invocation_id) gid: vec3) { } `; +export const ba_planner_bench = `{{> structs }} + +// GPU-side bin-packing planner for the v3 MSM bucket-accumulate +// pipeline. One thread per bucket; uses atomicAdd to reserve global +// per-pair slots in chunk_plan / scatter_plan and per-carry slots in +// carry_plan, then writes that bucket's entries. +// +// Inputs (per current level): +// counts: array per-bucket active count +// offsets: array per-bucket starting index in active_sums_old +// +// Outputs (filled in by this kernel for the current level): +// chunk_plan: array 2 u32 per (chunk_id, slot) — pair operand indices +// scatter_plan: array 1 u32 per (chunk_id, slot) — destination in active_sums_new +// carry_plan: array 2 u32 per carry slot — (src in old, dst in new) +// totals: array> [0]=total pairs, [1]=total carries, [2]=total new actives +// new_counts: array per-bucket new active count (for next level) +// new_offsets: array per-bucket new offset in active_sums_new (for next level) +// +// Convention: discard slot = M_new - 1 (the highest index in +// active_sums_new). Pad pair source indices = (pad_l_idx, pad_r_idx) +// supplied via params. All non-real chunk_plan / scatter_plan slots +// must be pre-padded to (pad_l_idx, pad_r_idx) and discard_idx by the +// host before each planner dispatch. +// +// params.x = B (bucket count) +// params.y = S (chunk size, slots per chunk) +// (pad_l_idx / pad_r_idx / discard_idx live in the pre-padded +// arrays, not in params) + +const S: u32 = {{ s }}u; + +@group(0) @binding(0) var counts: array; +@group(0) @binding(1) var offsets: array; +@group(0) @binding(2) var chunk_plan: array; +@group(0) @binding(3) var scatter_plan: array; +@group(0) @binding(4) var carry_plan: array; +@group(0) @binding(5) var totals: array>; +@group(0) @binding(6) var new_counts: array; +@group(0) @binding(7) var new_offsets: array; +@group(0) @binding(8) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let B = params.x; + let b = gid.x; + if (b >= B) { return; } + + let n = counts[b]; + let pair_count = n / 2u; + let carry_flag = n & 1u; + let nc = pair_count + carry_flag; + new_counts[b] = nc; + + // Atomic offset reservation. Each bucket gets a unique non-overlapping + // range in the global arrays. Atomic order is non-deterministic but + // that's fine: bucket b records its assigned offsets and uses them + // consistently for its own chunk_plan / scatter_plan / new_offsets + // writes. Different buckets land in different ranges by construction. + let my_pair_off = atomicAdd(&totals[0u], pair_count); + let my_carry_off = atomicAdd(&totals[1u], carry_flag); + let my_new_off = atomicAdd(&totals[2u], nc); + new_offsets[b] = my_new_off; + + let bucket_base = offsets[b]; + + // Write this bucket's pair entries into chunk_plan / scatter_plan. + // Loop bounded by pair_count (variable per bucket; typically ~16 + // for Poisson(λ=32)). The TAIL_CAP-style compile-time bound used + // by ba_tail_reduce isn't strictly needed here since this kernel + // doesn't do field arithmetic; the loop is plain integer writes. + // We still bound it by a compile-time constant for WGSL static + // analysis purposes. + let PAIR_CAP: u32 = {{ pair_cap }}u; + for (var j: u32 = 0u; j < PAIR_CAP; j = j + 1u) { + if (j >= pair_count) { break; } + let global_slot = my_pair_off + j; + let chunk_id = global_slot / S; + let slot_in_chunk = global_slot % S; + let cp_base = 2u * (chunk_id * S + slot_in_chunk); + chunk_plan[cp_base + 0u] = bucket_base + 2u * j; + chunk_plan[cp_base + 1u] = bucket_base + 2u * j + 1u; + scatter_plan[chunk_id * S + slot_in_chunk] = my_new_off + j; + } + + if (carry_flag != 0u) { + let cs = my_carry_off; + carry_plan[2u * cs + 0u] = bucket_base + n - 1u; + carry_plan[2u * cs + 1u] = my_new_off + pair_count; + } + + {{{ recompile }}} +} +`; + export const ba_rev_packed_carry_bench = `{{> structs }} {{> bigint_funcs }} {{> montgomery_product_funcs }} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_fused_super_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_fused_super_bench.template.wgsl new file mode 100644 index 000000000000..6bb5e6d964d7 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_fused_super_bench.template.wgsl @@ -0,0 +1,157 @@ +{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +// Fused super-kernel for the bin-packed pair-tree MSM bucket-accumulate. +// +// Combines marshal + disjoint + scatter into one kernel. Each thread t +// handles one chunk of S pairs: +// 1. Read 2*S source indices from chunk_plan (idx_l, idx_r per slot). +// 2. Read S destination indices from scatter_plan. +// 3. Load S pair-x values from active_sums_old, compute S dx values +// and forward prefix product, all in registers. +// 4. Single fr_inv_by_a on the prefix product. +// 5. Backward peel: per slot k from S-1 down to 0: +// - load .x and .y for both operands +// - lean affine add -> R_x, R_y +// - write directly to active_sums_new at scatter_plan[t*S + k] +// - update inv for next (smaller-k) iteration +// +// vs v2 (4 kernels: marshal, disjoint, scatter, carry): the chain_buf +// and tempOut scratch buffers are eliminated. All intermediate state +// lives in registers. Per-level dispatch count drops from 4 to 2 +// (fused + carry). +// +// PARAMS: +// params.x = T_chunks (active threads, one per chunk) +// params.y = M_old (active_sums_old vec4-stride length) +// params.z = M_new (active_sums_new vec4-stride length) +// +// Layout (both active_sums buffers): 2 planes (P.x, P.y), PG=2 vec4 per +// element. plane_p flat vec4 base = p * PG * M, element e at offset +// PG * e. + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var chunk_plan: array; +@group(0) @binding(1) var scatter_plan: array; +@group(0) @binding(2) var active_sums_old: array>; +@group(0) @binding(3) var active_sums_new: array>; +@group(0) @binding(4) var params: vec4; + +fn load_active_x(idx: u32, M: u32) -> BigInt { + let plane_base = 0u * PG * M; + let base = plane_base + PG * idx; + let q0 = active_sums_old[base + 0u]; + let q1 = active_sums_old[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn load_active_y(idx: u32, M: u32) -> BigInt { + let plane_base = 1u * PG * M; + let base = plane_base + PG * idx; + let q0 = active_sums_old[base + 0u]; + let q1 = active_sums_old[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_active_new(plane: u32, idx: u32, M: u32, val: ptr) { + let plane_base = plane * PG * M; + let base = plane_base + PG * idx; + let w = pack_limbs_to_256(val); + active_sums_new[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + active_sums_new[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let M_old = params.y; + let M_new = params.z; + let t = gid.x; + if (t >= T) { return; } + + let chunk_base = 2u * S * t; + + // Forward: compute S dx values and accumulate prefix product. + // Read pair indices from chunk_plan, load .x for each operand, compute dx. + var pref: array; + var acc: BigInt = get_r(); + for (var k: u32 = 0u; k < S; k = k + 1u) { + let idx_l = chunk_plan[chunk_base + 2u * k + 0u]; + let idx_r = chunk_plan[chunk_base + 2u * k + 1u]; + var p_lx: BigInt = load_active_x(idx_l, M_old); + var p_rx: BigInt = load_active_x(idx_r, M_old); + var dx: BigInt = fr_sub(&p_rx, &p_lx); + if (k == 0u) { + acc = dx; + } else { + acc = montgomery_product(&acc, &dx); + } + pref[k] = acc; + } + + // Single inversion per chunk. + var inv: BigInt = fr_inv_by_a(acc); + + // Backward peel: emit S pair sums, scatter to active_sums_new. + for (var jj: u32 = 0u; jj < S; jj = jj + 1u) { + let k = S - 1u - jj; + let idx_l = chunk_plan[chunk_base + 2u * k + 0u]; + let idx_r = chunk_plan[chunk_base + 2u * k + 1u]; + + var p_lx: BigInt = load_active_x(idx_l, M_old); + var p_ly: BigInt = load_active_y(idx_l, M_old); + var p_rx: BigInt = load_active_x(idx_r, M_old); + var p_ry: BigInt = load_active_y(idx_r, M_old); + + var inv_dx: BigInt; + if (k == 0u) { + inv_dx = inv; + } else { + var pp = pref[k - 1u]; + inv_dx = montgomery_product(&inv, &pp); + } + + var lambda: BigInt = fr_sub(&p_ry, &p_ly); + lambda = montgomery_product(&lambda, &inv_dx); + var r_x: BigInt = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &p_lx); + r_x = fr_sub(&r_x, &p_rx); + var r_y: BigInt = fr_sub(&p_lx, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &p_ly); + + let dst_idx = scatter_plan[t * S + k]; + store_active_new(0u, dst_idx, M_new, &r_x); + store_active_new(1u, dst_idx, M_new, &r_y); + + if (k > 0u) { + var dx_back: BigInt = fr_sub(&p_rx, &p_lx); + inv = montgomery_product(&inv, &dx_back); + } + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_bench.template.wgsl new file mode 100644 index 000000000000..1d49e4298849 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_bench.template.wgsl @@ -0,0 +1,93 @@ +{{> structs }} + +// GPU-side bin-packing planner for the v3 MSM bucket-accumulate +// pipeline. One thread per bucket; uses atomicAdd to reserve global +// per-pair slots in chunk_plan / scatter_plan and per-carry slots in +// carry_plan, then writes that bucket's entries. +// +// Inputs (per current level): +// counts: array per-bucket active count +// offsets: array per-bucket starting index in active_sums_old +// +// Outputs (filled in by this kernel for the current level): +// chunk_plan: array 2 u32 per (chunk_id, slot) — pair operand indices +// scatter_plan: array 1 u32 per (chunk_id, slot) — destination in active_sums_new +// carry_plan: array 2 u32 per carry slot — (src in old, dst in new) +// totals: array> [0]=total pairs, [1]=total carries, [2]=total new actives +// new_counts: array per-bucket new active count (for next level) +// new_offsets: array per-bucket new offset in active_sums_new (for next level) +// +// Convention: discard slot = M_new - 1 (the highest index in +// active_sums_new). Pad pair source indices = (pad_l_idx, pad_r_idx) +// supplied via params. All non-real chunk_plan / scatter_plan slots +// must be pre-padded to (pad_l_idx, pad_r_idx) and discard_idx by the +// host before each planner dispatch. +// +// params.x = B (bucket count) +// params.y = S (chunk size, slots per chunk) +// (pad_l_idx / pad_r_idx / discard_idx live in the pre-padded +// arrays, not in params) + +const S: u32 = {{ s }}u; + +@group(0) @binding(0) var counts: array; +@group(0) @binding(1) var offsets: array; +@group(0) @binding(2) var chunk_plan: array; +@group(0) @binding(3) var scatter_plan: array; +@group(0) @binding(4) var carry_plan: array; +@group(0) @binding(5) var totals: array>; +@group(0) @binding(6) var new_counts: array; +@group(0) @binding(7) var new_offsets: array; +@group(0) @binding(8) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let B = params.x; + let b = gid.x; + if (b >= B) { return; } + + let n = counts[b]; + let pair_count = n / 2u; + let carry_flag = n & 1u; + let nc = pair_count + carry_flag; + new_counts[b] = nc; + + // Atomic offset reservation. Each bucket gets a unique non-overlapping + // range in the global arrays. Atomic order is non-deterministic but + // that's fine: bucket b records its assigned offsets and uses them + // consistently for its own chunk_plan / scatter_plan / new_offsets + // writes. Different buckets land in different ranges by construction. + let my_pair_off = atomicAdd(&totals[0u], pair_count); + let my_carry_off = atomicAdd(&totals[1u], carry_flag); + let my_new_off = atomicAdd(&totals[2u], nc); + new_offsets[b] = my_new_off; + + let bucket_base = offsets[b]; + + // Write this bucket's pair entries into chunk_plan / scatter_plan. + // Loop bounded by pair_count (variable per bucket; typically ~16 + // for Poisson(λ=32)). The TAIL_CAP-style compile-time bound used + // by ba_tail_reduce isn't strictly needed here since this kernel + // doesn't do field arithmetic; the loop is plain integer writes. + // We still bound it by a compile-time constant for WGSL static + // analysis purposes. + let PAIR_CAP: u32 = {{ pair_cap }}u; + for (var j: u32 = 0u; j < PAIR_CAP; j = j + 1u) { + if (j >= pair_count) { break; } + let global_slot = my_pair_off + j; + let chunk_id = global_slot / S; + let slot_in_chunk = global_slot % S; + let cp_base = 2u * (chunk_id * S + slot_in_chunk); + chunk_plan[cp_base + 0u] = bucket_base + 2u * j; + chunk_plan[cp_base + 1u] = bucket_base + 2u * j + 1u; + scatter_plan[chunk_id * S + slot_in_chunk] = my_new_off + j; + } + + if (carry_flag != 0u) { + let cs = my_carry_off; + carry_plan[2u * cs + 0u] = bucket_base + n - 1u; + carry_plan[2u * cs + 1u] = my_new_off + pair_count; + } + + {{{ recompile }}} +} From 02ee6342008a5e95acd0c9d2810bff6d8b6219a2 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Wed, 20 May 2026 11:50:51 +0000 Subject: [PATCH 19/33] perf(bb/msm/v3): right-size fused + carry dispatches per level Previously dispatched MAX_CHUNKS=4160 fused threads per level. At L5 only ~117 are real, so 3979 pad-chunks each ran a full fr_inv_by_a + S mont muls = ~16ms of pure pad-chunk waste, dominating the 30ms wall and erasing the fusion win. Host simulates the bin-packing iteration upfront (cheap O(B*LEVELS) work on counts, no plan-content involvement) to determine T_chunks and T_carries per level. The GPU planner still writes the actual plan-buffer content via atomicAdd; the host just knows the dispatch shapes. --- .../ts/dev/msm-webgpu/bench-msm-tree-v3.ts | 59 ++++++++++++++----- 1 file changed, 44 insertions(+), 15 deletions(-) diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.ts index ea778518e1e2..4b420b00ee7d 100644 --- a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.ts +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.ts @@ -222,11 +222,43 @@ async function runPipeline(device: GPUDevice, sm: ShaderManager, R: bigint, p: b log('info', `bucket counts: min=${minC} max=${maxC} small(<32)=${smallC}/${BUCKETS}`); // Plan-buffer sizing — must accommodate L0 max chunks. - // At L0: total pairs <= N/2 = 65536; max chunks = ceil(65536/S) = 4096. - // At deeper levels: shrinks. So max-allocated for L0. - const MAX_CHUNKS = Math.ceil(NPTS / 2 / S) + 16; // pad + const MAX_CHUNKS = Math.ceil(NPTS / 2 / S) + 16; const MAX_PAIR_SLOTS = MAX_CHUNKS * S; - const MAX_CARRIES = BUCKETS; // at most one carry per bucket + const MAX_CARRIES = BUCKETS; + + // Host simulates the bin-packing iteration to compute the right + // dispatch size per level. The GPU planner does the same work to + // fill plan buffers; this host loop only computes sizes (T_chunks, + // T_carries) so the host can dispatch the fused + carry kernels at + // the correct size per level, avoiding pad-chunk waste. + // + // Without this, the fused kernel runs MAX_CHUNKS=4160 threads per + // level regardless of actual work. At L5 with only ~117 real chunks + // that's ~3979 pad-chunks each running a full fr_inv_by_a + S mont + // muls -- dominates wall time. + const perLevelTChunks: number[] = []; + const perLevelTCarries: number[] = []; + { + let cur = new Uint32Array(initCounts); + for (let lv = 0; lv < LEVELS; lv++) { + let totalPairs = 0; + let totalCarries = 0; + const next = new Uint32Array(cur.length); + for (let b = 0; b < cur.length; b++) { + const n = cur[b]; + const p = (n / 2) | 0; + const c = n & 1; + totalPairs += p; + totalCarries += c; + next[b] = p + c; + } + perLevelTChunks.push(Math.max(1, Math.ceil(totalPairs / S))); + perLevelTCarries.push(Math.max(1, totalCarries)); + cur = next; + } + log('info', `host-simulated per-level T_chunks: ${perLevelTChunks.join(', ')}`); + log('info', `host-simulated per-level T_carries: ${perLevelTCarries.join(', ')}`); + } const mkStorage = (bytes: number, copyDst = true, copySrc = false): GPUBuffer => { let usage = GPUBufferUsage.STORAGE; @@ -344,17 +376,14 @@ async function runPipeline(device: GPUDevice, sm: ShaderManager, R: bigint, p: b let curActiveOut: GPUBuffer = bufB; const numWgsPlanner = Math.ceil(BUCKETS / WGI); - const numWgsFused = Math.ceil(MAX_CHUNKS / WGI); - const numWgsCarry = Math.ceil(MAX_CARRIES / WGI); - log('info', `dispatch sizes: planner=${numWgsPlanner} fused=${numWgsFused} carry=${numWgsCarry}`); + const numWgsFusedPerLevel = perLevelTChunks.map(t => Math.ceil(t / WGI)); + const numWgsCarryPerLevel = perLevelTCarries.map(t => Math.ceil(t / WGI)); + log('info', `numWgs: planner=${numWgsPlanner}, fused=${numWgsFusedPerLevel.join(',')}, carry=${numWgsCarryPerLevel.join(',')}`); - // Pre-write per-level params (since they depend on iteration index). + // Per-level params with the right-sized T from the host bin-pack simulator. for (let lv = 0; lv < LEVELS; lv++) { - // params.x = T_fused = MAX_CHUNKS (over-provisioned; planner-written - // chunk_plan/scatter_plan trailers point to pad => safe early-out - // is implicit because pads do harmless add+discard) - device.queue.writeBuffer(fusedParams[lv], 0, new Uint32Array([MAX_CHUNKS, M, M, 0])); - device.queue.writeBuffer(carryParams[lv], 0, new Uint32Array([MAX_CARRIES, M, M, 0])); + device.queue.writeBuffer(fusedParams[lv], 0, new Uint32Array([perLevelTChunks[lv], M, M, 0])); + device.queue.writeBuffer(carryParams[lv], 0, new Uint32Array([perLevelTCarries[lv], M, M, 0])); } // Pre-pad plan buffers (only need to do once; planner overwrites real @@ -421,14 +450,14 @@ async function runPipeline(device: GPUDevice, sm: ShaderManager, R: bigint, p: b const pass = enc.beginComputePass(); pass.setPipeline(fusedPipe); pass.setBindGroup(0, fusedBind); - pass.dispatchWorkgroups(numWgsFused, 1, 1); + pass.dispatchWorkgroups(numWgsFusedPerLevel[lv], 1, 1); pass.end(); } { const pass = enc.beginComputePass(); pass.setPipeline(carryPipe); pass.setBindGroup(0, carryBind); - pass.dispatchWorkgroups(numWgsCarry, 1, 1); + pass.dispatchWorkgroups(numWgsCarryPerLevel[lv], 1, 1); pass.end(); } From 402ebdeb622c8c3ad9495c01dcc023f55724b3c6 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Wed, 20 May 2026 12:17:53 +0000 Subject: [PATCH 20/33] feat(bb/msm): standalone GPU bin-packing planner v2 + microbench MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Single-kernel scan + scatter planner (ba_planner_v2_bench.wgsl). One workgroup of TPB threads handles all B buckets: Phase A: per-thread local accumulate over PER_THREAD buckets (each thread reads its slice of counts[], computes pair_count + carry_flag + new_count in registers). Phase B: workgroup-wide Hillis-Steele scan over per-thread totals via shared memory — 3 scans (pair, carry, new) in one pass. Phase C: per-thread scatter — each thread walks its slice and writes chunk_plan + scatter_plan + carry_plan + new_counts + new_offsets entries with running thread-local offsets. Phase D: last thread writes totals[] (= total_pairs, total_carries, total_new_actives) for the level. No atomics, no host sync, single dispatch. Scales to B <= TPB * PER_THREAD within one workgroup (e.g. 256 * 32 = 8192 buckets). Standalone bench-planner.{ts,html} harness exercises the planner in isolation from the MSM pipeline: - Generates Poisson(λ)-distributed synthetic counts on host. - Encodes DISP back-to-back planner dispatches into ONE command encoder per timed sample (amortises submit overhead). - Reports per-planner-dispatch microseconds (min / median / max). - ?validate=1 reads outputs back and cross-checks against a host- side bin-pack reference (byte-equivalent compare). shader_manager.gen_ba_planner_v2_bench_shader(workgroup_size, per_thread, s, pair_cap). pageMap entry "bench-planner". --- .../ts/dev/msm-webgpu/bench-planner.html | 22 + .../ts/dev/msm-webgpu/bench-planner.ts | 423 ++++++++++++++++++ .../msm-webgpu/scripts/run-browserstack.mjs | 1 + .../ts/src/msm_webgpu/cuzk/shader_manager.ts | 20 + .../src/msm_webgpu/wgsl/_generated/shaders.ts | 174 ++++++- .../cuzk/ba_planner_v2_bench.template.wgsl | 170 +++++++ 6 files changed, 809 insertions(+), 1 deletion(-) create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-planner.html create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-planner.ts create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_bench.template.wgsl diff --git a/barretenberg/ts/dev/msm-webgpu/bench-planner.html b/barretenberg/ts/dev/msm-webgpu/bench-planner.html new file mode 100644 index 000000000000..dd773cd4d504 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-planner.html @@ -0,0 +1,22 @@ + + + + + Standalone GPU planner microbench (WebGPU) + + + +

Standalone GPU planner microbench

+

Query params: ?buckets=B&lambda=L&s=S&tpb=T&per=P&disp=D&reps=R&validate=1

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-planner.ts b/barretenberg/ts/dev/msm-webgpu/bench-planner.ts new file mode 100644 index 000000000000..210da3c73b83 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-planner.ts @@ -0,0 +1,423 @@ +/// +// Standalone microbench for the GPU bin-packing planner kernel. +// Isolates the planner from the rest of the MSM pipeline so we can +// pin down the minimum time required to build a per-level plan +// (chunk_plan + scatter_plan + carry_plan + new_counts/offsets) on +// the GPU. +// +// Inputs (synthetic, host-built upfront): +// counts[B] per-bucket active count drawn from Poisson(lambda). +// offsets[B+1] prefix sum of counts. +// +// Each planner dispatch is one workgroup of TPB threads. Each thread +// handles PER_THREAD buckets. B = TPB * PER_THREAD. +// +// Timing methodology: +// - Compile pipeline. +// - Warmup (1 dispatch). +// - For each rep: +// Encode DISP back-to-back dispatches in ONE command encoder. +// performance.now() right before submit and after await. +// - Per-planner time = sample / DISP. +// - Report min, median, max across reps. +// +// Validation (?validate=1): +// - Run 1 dispatch. +// - Read back chunk_plan, scatter_plan, carry_plan, new_counts, +// new_offsets, totals. +// - Cross-check against a host-side bin-pack reference. + +import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js'; +import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js'; +import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js'; +import { makeResultsClient } from './results_post.js'; + +let BUCKETS = 4096; +let LAMBDA = 32; // mean per-bucket count +let S = 16; +let TPB = 256; +let PER_THREAD = 16; // BUCKETS / TPB +let DISP = 128; +let REPS = 5; +let VALIDATE = false; + +function makeRng(seed: number): () => number { + let state = (seed >>> 0) || 1; + return () => { + state = (Math.imul(state, 1664525) + 1013904223) >>> 0; + return state; + }; +} + +// Approximate Poisson(λ) sample via the Knuth method. Adequate for +// generating synthetic bucket-count distributions in [0, ~3λ]. +function poisson(lambda: number, rng: () => number): number { + const L = Math.exp(-lambda); + let k = 0; + let p = 1.0; + for (;;) { + k++; + const u = (rng() >>> 0) / 0x100000000; + p *= u; + if (p <= L) break; + if (k > 200) break; + } + return k - 1; +} + +function buildSyntheticCounts(B: number, lambda: number, seed: number): { counts: Uint32Array; offsets: Uint32Array } { + const rng = makeRng(seed); + const counts = new Uint32Array(B); + for (let b = 0; b < B; b++) counts[b] = poisson(lambda, rng); + const offsets = new Uint32Array(B + 1); + for (let b = 0; b < B; b++) offsets[b + 1] = offsets[b] + counts[b]; + return { counts, offsets }; +} + +// Host-side bin-pack reference. Returns the EXACT same outputs the GPU +// planner is expected to produce (modulo per-bucket atomic-ordering +// differences, which the v2 planner avoids — its order matches host). +function buildHostReference(counts: Uint32Array, offsets: Uint32Array, S: number) { + const B = counts.length; + let totalPairs = 0; + let totalCarries = 0; + let totalNew = 0; + const newCounts = new Uint32Array(B); + const newOffsets = new Uint32Array(B + 1); + // First pass: compute new_counts and accumulate totals. + for (let b = 0; b < B; b++) { + const n = counts[b]; + const pc = Math.floor(n / 2); + const cf = n & 1; + newCounts[b] = pc + cf; + totalPairs += pc; + totalCarries += cf; + totalNew += pc + cf; + } + for (let b = 0; b < B; b++) newOffsets[b + 1] = newOffsets[b] + newCounts[b]; + const numChunks = Math.max(1, Math.ceil(totalPairs / S)); + const chunkPlan = new Uint32Array(2 * numChunks * S); + const scatterPlan = new Uint32Array(numChunks * S); + const carryPlan = new Uint32Array(2 * Math.max(1, totalCarries)); + let pairOff = 0; + let carryOff = 0; + for (let b = 0; b < B; b++) { + const n = counts[b]; + const pc = Math.floor(n / 2); + const cf = n & 1; + const bucketBase = offsets[b]; + for (let j = 0; j < pc; j++) { + const slot = pairOff + j; + const chunkId = Math.floor(slot / S); + const slotInChunk = slot % S; + const cpBase = 2 * (chunkId * S + slotInChunk); + chunkPlan[cpBase + 0] = bucketBase + 2 * j; + chunkPlan[cpBase + 1] = bucketBase + 2 * j + 1; + scatterPlan[chunkId * S + slotInChunk] = newOffsets[b] + j; + } + if (cf) { + carryPlan[2 * carryOff + 0] = bucketBase + n - 1; + carryPlan[2 * carryOff + 1] = newOffsets[b] + pc; + carryOff++; + } + pairOff += pc; + } + return { chunkPlan, scatterPlan, carryPlan, newCounts, newOffsets, totalPairs, totalCarries, totalNew, numChunks }; +} + +function median(xs: number[]): number { + if (xs.length === 0) return NaN; + const s = xs.slice().sort((a, b) => a - b); + return s[Math.floor(s.length / 2)]; +} + +interface BenchResult { + buckets: number; + lambda: number; + s: number; + tpb: number; + per_thread: number; + disp: number; + reps: number; + total_pairs: number; + total_carries: number; + num_chunks: number; + per_dispatch_us: { min: number; median: number; max: number }; + wall_samples_ms: number[]; + validated: boolean; +} + +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: Record | null; + results: BenchResult[]; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { state: 'boot', params: null, results: [], error: null, log: [] }; +(window as unknown as { __bench: BenchState }).__bench = benchState; +const resultsClient = makeResultsClient({ page: 'bench-planner' }); +(window as unknown as { __runId: string }).__runId = resultsClient.runId; + +async function postFinal(): Promise { + await resultsClient.postResults({ + state: benchState.state, params: benchState.params, results: benchState.results, + error: benchState.error, log: benchState.log, + userAgent: navigator.userAgent, hardwareConcurrency: navigator.hardwareConcurrency, + }); +} + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-planner] ${msg}`); +} + +async function compileOne(device: GPUDevice, code: string, key: string, layout: GPUBindGroupLayout): Promise { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + const errLines: string[] = []; + for (const m of info.messages) { + const line = `[shader ${key}] ${m.type}: ${m.message} (line ${m.lineNum}, col ${m.linePos})`; + if (m.type === 'error') { console.error(line); log('err', line); errLines.push(line); hasError = true; } + else { console.warn(line); } + } + if (hasError) throw new Error(`WGSL compile failed for ${key}: ${errLines.slice(0, 4).join(' | ')}`); + return device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); +} + +async function readbackU32(device: GPUDevice, buf: GPUBuffer, byteLength: number): Promise { + const staging = device.createBuffer({ size: byteLength, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }); + const enc = device.createCommandEncoder(); + enc.copyBufferToBuffer(buf, 0, staging, 0, byteLength); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + await staging.mapAsync(GPUMapMode.READ); + const out = new Uint32Array(staging.getMappedRange().slice(0)); + staging.unmap(); + staging.destroy(); + return out; +} + +async function runOne(device: GPUDevice, sm: ShaderManager): Promise { + log('info', `=== B=${BUCKETS} λ=${LAMBDA} S=${S} TPB=${TPB} PER=${PER_THREAD} DISP=${DISP} REPS=${REPS}`); + if (TPB * PER_THREAD !== BUCKETS) throw new Error(`BUCKETS=${BUCKETS} must equal TPB*PER_THREAD=${TPB * PER_THREAD}`); + + const { counts, offsets } = buildSyntheticCounts(BUCKETS, LAMBDA, 0x5fa11); + let totalActive = 0; + let cMin = 99999, cMax = 0; + for (let b = 0; b < BUCKETS; b++) { + totalActive += counts[b]; + if (counts[b] > cMax) cMax = counts[b]; + if (counts[b] < cMin) cMin = counts[b]; + } + log('info', `synthetic counts: min=${cMin} max=${cMax} totalActive=${totalActive}`); + + const ref = buildHostReference(counts, offsets, S); + log('info', `host reference: totalPairs=${ref.totalPairs} totalCarries=${ref.totalCarries} numChunks=${ref.numChunks}`); + + // Allocate output buffers sized for the host-computed plan. + // For a real MSM these sizes would be conservative max bounds; here + // we use exact host values for tighter validation. + const mkStorage = (bytes: number, copySrc = false, copyDst = false): GPUBuffer => { + let usage = GPUBufferUsage.STORAGE; + if (copySrc) usage |= GPUBufferUsage.COPY_SRC; + if (copyDst) usage |= GPUBufferUsage.COPY_DST; + return device.createBuffer({ size: bytes, usage }); + }; + + const countsBuf = mkStorage(counts.byteLength, false, true); + const offsetsBuf = mkStorage(offsets.byteLength, false, true); + device.queue.writeBuffer(countsBuf, 0, counts); + device.queue.writeBuffer(offsetsBuf, 0, offsets); + + const chunkPlanBytes = ref.chunkPlan.byteLength; + const scatterPlanBytes = ref.scatterPlan.byteLength; + const carryPlanBytes = ref.carryPlan.byteLength; + const newCountsBytes = ref.newCounts.byteLength; + const newOffsetsBytes = ref.newOffsets.byteLength; + const totalsBytes = 16; + + const chunkPlanBuf = mkStorage(chunkPlanBytes, true); + const scatterPlanBuf = mkStorage(scatterPlanBytes, true); + const carryPlanBuf = mkStorage(carryPlanBytes, true); + const newCountsBuf = mkStorage(newCountsBytes, true); + const newOffsetsBuf = mkStorage(newOffsetsBytes, true); + const totalsBuf = mkStorage(totalsBytes, true); + + const paramsBuf = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(paramsBuf, 0, new Uint32Array([BUCKETS, S, 0, 0])); + + const layout = device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 5, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 6, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 7, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 8, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); + const pipeline = await compileOne(device, sm.gen_ba_planner_v2_bench_shader(TPB, PER_THREAD, S, 64), `planner-v2-T${TPB}-P${PER_THREAD}-S${S}`, layout); + const bind = device.createBindGroup({ + layout, + entries: [ + { binding: 0, resource: { buffer: countsBuf } }, + { binding: 1, resource: { buffer: offsetsBuf } }, + { binding: 2, resource: { buffer: chunkPlanBuf } }, + { binding: 3, resource: { buffer: scatterPlanBuf } }, + { binding: 4, resource: { buffer: carryPlanBuf } }, + { binding: 5, resource: { buffer: newCountsBuf } }, + { binding: 6, resource: { buffer: newOffsetsBuf } }, + { binding: 7, resource: { buffer: totalsBuf } }, + { binding: 8, resource: { buffer: paramsBuf } }, + ], + }); + + // Warmup. + { + const enc = device.createCommandEncoder(); + const pass = enc.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroups(1, 1, 1); + pass.end(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + } + log('info', 'warmup done'); + + // Validation (one dispatch, read back, compare against host reference). + let validated = false; + if (VALIDATE) { + const gpuTotals = await readbackU32(device, totalsBuf, totalsBytes); + const gpuNewCounts = await readbackU32(device, newCountsBuf, newCountsBytes); + const gpuNewOffsets = await readbackU32(device, newOffsetsBuf, newOffsetsBytes); + const gpuChunkPlan = await readbackU32(device, chunkPlanBuf, chunkPlanBytes); + const gpuScatterPlan = await readbackU32(device, scatterPlanBuf, scatterPlanBytes); + const gpuCarryPlan = await readbackU32(device, carryPlanBuf, carryPlanBytes); + + const mismatches: string[] = []; + if (gpuTotals[0] !== ref.totalPairs) mismatches.push(`totals[0]: gpu=${gpuTotals[0]} ref=${ref.totalPairs}`); + if (gpuTotals[1] !== ref.totalCarries) mismatches.push(`totals[1]: gpu=${gpuTotals[1]} ref=${ref.totalCarries}`); + if (gpuTotals[2] !== ref.totalNew) mismatches.push(`totals[2]: gpu=${gpuTotals[2]} ref=${ref.totalNew}`); + for (let b = 0; b < BUCKETS && mismatches.length < 8; b++) { + if (gpuNewCounts[b] !== ref.newCounts[b]) mismatches.push(`newCounts[${b}]: gpu=${gpuNewCounts[b]} ref=${ref.newCounts[b]}`); + if (gpuNewOffsets[b] !== ref.newOffsets[b]) mismatches.push(`newOffsets[${b}]: gpu=${gpuNewOffsets[b]} ref=${ref.newOffsets[b]}`); + } + // chunk_plan/scatter_plan: compare element-wise. + let cpFails = 0; + for (let i = 0; i < ref.chunkPlan.length; i++) { + if (gpuChunkPlan[i] !== ref.chunkPlan[i]) { cpFails++; if (cpFails <= 3) mismatches.push(`chunkPlan[${i}]: gpu=${gpuChunkPlan[i]} ref=${ref.chunkPlan[i]}`); } + } + let spFails = 0; + for (let i = 0; i < ref.scatterPlan.length; i++) { + if (gpuScatterPlan[i] !== ref.scatterPlan[i]) { spFails++; if (spFails <= 3) mismatches.push(`scatterPlan[${i}]: gpu=${gpuScatterPlan[i]} ref=${ref.scatterPlan[i]}`); } + } + let cyFails = 0; + for (let i = 0; i < 2 * ref.totalCarries; i++) { + if (gpuCarryPlan[i] !== ref.carryPlan[i]) { cyFails++; if (cyFails <= 3) mismatches.push(`carryPlan[${i}]: gpu=${gpuCarryPlan[i]} ref=${ref.carryPlan[i]}`); } + } + if (mismatches.length === 0 && cpFails === 0 && spFails === 0 && cyFails === 0) { + validated = true; + log('ok', 'validation: PASS — GPU planner output byte-equivalent to host reference'); + } else { + log('err', `validation: FAIL — ${cpFails} chunkPlan, ${spFails} scatterPlan, ${cyFails} carryPlan mismatches; first few:`); + for (const m of mismatches.slice(0, 10)) log('err', ` ${m}`); + } + } + + // Timed runs: DISP back-to-back planner dispatches in one command encoder. + const samples: number[] = []; + for (let r = 0; r < REPS; r++) { + const enc = device.createCommandEncoder(); + for (let d = 0; d < DISP; d++) { + const pass = enc.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroups(1, 1, 1); + pass.end(); + } + const t0 = performance.now(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + samples.push(performance.now() - t0); + } + const med = median(samples); + const mn = Math.min(...samples); + const mx = Math.max(...samples); + const perDispatchMin = (mn / DISP) * 1000; + const perDispatchMed = (med / DISP) * 1000; + const perDispatchMax = (mx / DISP) * 1000; + + log( + 'ok', + `per-planner: min=${perDispatchMin.toFixed(2)}μs median=${perDispatchMed.toFixed(2)}μs max=${perDispatchMax.toFixed(2)}μs` + + ` (total wall: min=${mn.toFixed(2)}ms median=${med.toFixed(2)}ms max=${mx.toFixed(2)}ms over DISP=${DISP})`, + ); + + countsBuf.destroy(); offsetsBuf.destroy(); chunkPlanBuf.destroy(); scatterPlanBuf.destroy(); + carryPlanBuf.destroy(); newCountsBuf.destroy(); newOffsetsBuf.destroy(); totalsBuf.destroy(); + paramsBuf.destroy(); + + return { + buckets: BUCKETS, lambda: LAMBDA, s: S, tpb: TPB, per_thread: PER_THREAD, disp: DISP, reps: REPS, + total_pairs: ref.totalPairs, total_carries: ref.totalCarries, num_chunks: ref.numChunks, + per_dispatch_us: { min: perDispatchMin, median: perDispatchMed, max: perDispatchMax }, + wall_samples_ms: samples, + validated, + }; +} + +function parseParams() { + const qp = new URLSearchParams(window.location.search); + if (qp.get('buckets')) BUCKETS = parseInt(qp.get('buckets')!, 10); + if (qp.get('lambda')) LAMBDA = parseInt(qp.get('lambda')!, 10); + if (qp.get('s')) S = parseInt(qp.get('s')!, 10); + if (qp.get('tpb')) TPB = parseInt(qp.get('tpb')!, 10); + if (qp.get('per')) PER_THREAD = parseInt(qp.get('per')!, 10); + if (qp.get('disp')) DISP = parseInt(qp.get('disp')!, 10); + if (qp.get('reps')) REPS = parseInt(qp.get('reps')!, 10); + if (qp.get('validate') === '1') VALIDATE = true; + return { buckets: BUCKETS, lambda: LAMBDA, s: S, tpb: TPB, per: PER_THREAD, disp: DISP, reps: REPS, validate: VALIDATE }; +} + +async function main() { + try { + if (!('gpu' in navigator)) throw new Error('navigator.gpu missing'); + const params = parseParams(); + benchState.params = params; + log('info', `params: ${JSON.stringify(params)}`); + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + const sm = new ShaderManager(4, BUCKETS, BN254_CURVE_CONFIG, false); + const r = await runOne(device, sm); + benchState.results.push(r); + resultsClient.postProgress({ kind: 'planner_done', per_dispatch_us: r.per_dispatch_us, validated: r.validated }); + benchState.state = 'done'; + log('ok', 'done'); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main() + .catch(e => { const msg = e instanceof Error ? e.message : String(e); log('err', `unhandled: ${msg}`); benchState.state = 'error'; benchState.error = msg; }) + .finally(() => { postFinal().catch(() => {}); }); diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs index b24956e6ca9d..44247b10b8ab 100644 --- a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs +++ b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs @@ -134,6 +134,7 @@ const pageMap = { "bench-msm-tree": "/dev/msm-webgpu/bench-msm-tree.html", "bench-msm-tree-v2": "/dev/msm-webgpu/bench-msm-tree-v2.html", "bench-msm-tree-v3": "/dev/msm-webgpu/bench-msm-tree-v3.html", + "bench-planner": "/dev/msm-webgpu/bench-planner.html", "bench-smvp-tree": "/dev/msm-webgpu/bench-smvp-tree.html", sanity: "/dev/msm-webgpu/index.html", }; diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts index b30b52bf3a48..24fe0121c2f4 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts @@ -13,6 +13,7 @@ import { ba_pair_disjoint_bench as ba_pair_disjoint_bench_shader, ba_pair_disjoint_tree_bench as ba_pair_disjoint_tree_bench_shader, ba_planner_bench as ba_planner_bench_shader, + ba_planner_v2_bench as ba_planner_v2_bench_shader, ba_scatter_pairs_bench as ba_scatter_pairs_bench_shader, ba_tail_reduce_bench as ba_tail_reduce_bench_shader, ba_rev_packed_carry_bench as ba_rev_packed_carry_bench_shader, @@ -822,6 +823,25 @@ ${packLines.join('\n')} ); } + /** + * v2 GPU planner: single-kernel scan + scatter. One workgroup of TPB + * threads handles all B buckets via per-thread local scan + workgroup- + * wide Hillis-Steele scan + per-thread scatter. No atomics, no host + * sync, single dispatch. Scales to B <= TPB * PER_THREAD within one + * workgroup (e.g. 256 * 32 = 8192 buckets). + */ + public gen_ba_planner_v2_bench_shader(workgroup_size: number, per_thread: number, s: number, pair_cap: number = 64): string { + if (workgroup_size <= 0 || per_thread <= 0 || s <= 0 || pair_cap <= 0 || + !Number.isInteger(workgroup_size) || !Number.isInteger(per_thread) || !Number.isInteger(s) || !Number.isInteger(pair_cap)) { + throw new Error(`gen_ba_planner_v2_bench_shader: positive integer args required`); + } + return mustache.render( + ba_planner_v2_bench_shader, + { workgroup_size, per_thread, pair_cap, s, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + /** * GPU-side bin-packing planner for the v3 pipeline. One thread per * bucket; atomicAdd reserves global per-pair / per-carry / per-new- diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index 0442dd802a74..232ae7a49364 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -1,6 +1,6 @@ // AUTO-GENERATED by scripts/inline-wgsl.mjs. DO NOT EDIT. // Run `yarn generate:wgsl` (or `node scripts/inline-wgsl.mjs`) to regenerate. -// 61 shader sources inlined. +// 62 shader sources inlined. /* eslint-disable */ @@ -2212,6 +2212,178 @@ fn main(@builtin(global_invocation_id) gid: vec3) { } `; +export const ba_planner_v2_bench = `{{> structs }} + +// Optimal single-kernel GPU bin-packing planner for the MSM +// bucket-accumulate pair-tree. +// +// One workgroup of TPB threads processes B buckets. Each thread +// handles PER_THREAD = B / TPB buckets via a contiguous slice +// [tid * PER_THREAD, (tid+1) * PER_THREAD). +// +// Phase A — Per-thread local scan +// For each of its PER_THREAD buckets, compute (pair_count, carry_flag, +// new_count). Accumulate per-thread totals (sum across the thread's +// slice). Keep the per-bucket triples in registers; we will re-scan +// them in Phase B. +// +// Phase B — Workgroup-wide Hillis-Steele scan (3 in parallel) +// Scan the per-thread totals for pair, carry, new across the TPB +// threads in shared memory. Result: each thread gets the global +// prefix sum at the START of its slice (= base offset for its first +// bucket). +// +// Phase C — Per-thread scatter +// For each bucket in the thread's slice (in order), use the running +// thread-local offset to compute global pair_offset_b and write the +// pair_count[b] chunk_plan entries plus the (optional) carry_plan +// entry. Update local running offsets. Write new_counts[b] and +// new_offsets[b] for the next level. +// +// Phase D — One thread writes totals. +// totals[0] = total_pairs, totals[1] = total_carries, +// totals[2] = total_new_actives. +// +// Single dispatch. No atomics. No host sync. Scales to B = TPB * +// PER_THREAD (e.g. 256 * 32 = 8192) within one workgroup. Larger B +// requires multi-workgroup scan + global combine (out of scope here). +// +// Compile-time constants: +// TPB : workgroup size (e.g. 256) +// PER_THREAD : buckets per thread (e.g. 16 for B=4096, 32 for B=8192) +// PAIR_CAP : bound on per-bucket pair count (Poisson(λ=32) tail +// is ~30; choose 64 for safety) +// S : chunk size in pairs (e.g. 16) + +const TPB: u32 = {{ workgroup_size }}u; +const PER_THREAD: u32 = {{ per_thread }}u; +const PAIR_CAP: u32 = {{ pair_cap }}u; +const S: u32 = {{ s }}u; + +@group(0) @binding(0) var counts: array; +@group(0) @binding(1) var offsets: array; +@group(0) @binding(2) var chunk_plan: array; +@group(0) @binding(3) var scatter_plan: array; +@group(0) @binding(4) var carry_plan: array; +@group(0) @binding(5) var new_counts: array; +@group(0) @binding(6) var new_offsets: array; +@group(0) @binding(7) var totals: array; +@group(0) @binding(8) var params: vec4; +// params.x = B + +// Workgroup-shared running prefixes for the 3 scans. +var pair_scan: array; +var carry_scan: array; +var new_scan: array; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; + let B = params.x; + + // Phase A: per-thread local read + accumulate. + // Keep PER_THREAD bucket triples in registers (small array). + var local_pc: array; + var local_cf: array; + var local_nc: array; + var sum_p: u32 = 0u; + var sum_c: u32 = 0u; + var sum_n: u32 = 0u; + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tid * PER_THREAD + k; + var pc: u32 = 0u; + var cf: u32 = 0u; + var nc: u32 = 0u; + if (b < B) { + let n = counts[b]; + pc = n / 2u; + cf = n & 1u; + nc = pc + cf; + } + local_pc[k] = pc; + local_cf[k] = cf; + local_nc[k] = nc; + sum_p += pc; + sum_c += cf; + sum_n += nc; + } + + // Phase B: workgroup-wide Hillis-Steele inclusive scan over per- + // thread totals (3 scans interleaved). + pair_scan[tid] = sum_p; + carry_scan[tid] = sum_c; + new_scan[tid] = sum_n; + workgroupBarrier(); + for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) { + var add_p: u32 = 0u; + var add_c: u32 = 0u; + var add_n: u32 = 0u; + if (tid >= stride) { + add_p = pair_scan[tid - stride]; + add_c = carry_scan[tid - stride]; + add_n = new_scan[tid - stride]; + } + workgroupBarrier(); + if (tid >= stride) { + pair_scan[tid] = pair_scan[tid] + add_p; + carry_scan[tid] = carry_scan[tid] + add_c; + new_scan[tid] = new_scan[tid] + add_n; + } + workgroupBarrier(); + } + // pair_scan[tid] is now inclusive prefix. Exclusive base = inclusive - own_sum. + var local_pair_off: u32 = pair_scan[tid] - sum_p; + var local_carry_off: u32 = carry_scan[tid] - sum_c; + var local_new_off: u32 = new_scan[tid] - sum_n; + + // Phase D: thread 0 writes totals (using the FINAL inclusive scan). + if (tid == TPB - 1u) { + totals[0] = pair_scan[tid]; + totals[1] = carry_scan[tid]; + totals[2] = new_scan[tid]; + } + + // Phase C: per-thread scatter. + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tid * PER_THREAD + k; + if (b >= B) { break; } + + let pc = local_pc[k]; + let cf = local_cf[k]; + let nc = local_nc[k]; + new_counts[b] = nc; + new_offsets[b] = local_new_off; + + let bucket_base = offsets[b]; + + // Pair entries: bounded loop, break at pc. + for (var j: u32 = 0u; j < PAIR_CAP; j = j + 1u) { + if (j >= pc) { break; } + let global_slot = local_pair_off + j; + let chunk_id = global_slot / S; + let slot_in_chunk = global_slot % S; + let cp_base = 2u * (chunk_id * S + slot_in_chunk); + chunk_plan[cp_base + 0u] = bucket_base + 2u * j; + chunk_plan[cp_base + 1u] = bucket_base + 2u * j + 1u; + scatter_plan[chunk_id * S + slot_in_chunk] = local_new_off + j; + } + + // Carry entry (if odd count). + if (cf != 0u) { + let cs = local_carry_off; + carry_plan[2u * cs + 0u] = bucket_base + counts[b] - 1u; + carry_plan[2u * cs + 1u] = local_new_off + pc; + } + + local_pair_off += pc; + local_carry_off += cf; + local_new_off += nc; + } + + {{{ recompile }}} +} +`; + export const ba_rev_packed_carry_bench = `{{> structs }} {{> bigint_funcs }} {{> montgomery_product_funcs }} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_bench.template.wgsl new file mode 100644 index 000000000000..789747d86f82 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_bench.template.wgsl @@ -0,0 +1,170 @@ +{{> structs }} + +// Optimal single-kernel GPU bin-packing planner for the MSM +// bucket-accumulate pair-tree. +// +// One workgroup of TPB threads processes B buckets. Each thread +// handles PER_THREAD = B / TPB buckets via a contiguous slice +// [tid * PER_THREAD, (tid+1) * PER_THREAD). +// +// Phase A — Per-thread local scan +// For each of its PER_THREAD buckets, compute (pair_count, carry_flag, +// new_count). Accumulate per-thread totals (sum across the thread's +// slice). Keep the per-bucket triples in registers; we will re-scan +// them in Phase B. +// +// Phase B — Workgroup-wide Hillis-Steele scan (3 in parallel) +// Scan the per-thread totals for pair, carry, new across the TPB +// threads in shared memory. Result: each thread gets the global +// prefix sum at the START of its slice (= base offset for its first +// bucket). +// +// Phase C — Per-thread scatter +// For each bucket in the thread's slice (in order), use the running +// thread-local offset to compute global pair_offset_b and write the +// pair_count[b] chunk_plan entries plus the (optional) carry_plan +// entry. Update local running offsets. Write new_counts[b] and +// new_offsets[b] for the next level. +// +// Phase D — One thread writes totals. +// totals[0] = total_pairs, totals[1] = total_carries, +// totals[2] = total_new_actives. +// +// Single dispatch. No atomics. No host sync. Scales to B = TPB * +// PER_THREAD (e.g. 256 * 32 = 8192) within one workgroup. Larger B +// requires multi-workgroup scan + global combine (out of scope here). +// +// Compile-time constants: +// TPB : workgroup size (e.g. 256) +// PER_THREAD : buckets per thread (e.g. 16 for B=4096, 32 for B=8192) +// PAIR_CAP : bound on per-bucket pair count (Poisson(λ=32) tail +// is ~30; choose 64 for safety) +// S : chunk size in pairs (e.g. 16) + +const TPB: u32 = {{ workgroup_size }}u; +const PER_THREAD: u32 = {{ per_thread }}u; +const PAIR_CAP: u32 = {{ pair_cap }}u; +const S: u32 = {{ s }}u; + +@group(0) @binding(0) var counts: array; +@group(0) @binding(1) var offsets: array; +@group(0) @binding(2) var chunk_plan: array; +@group(0) @binding(3) var scatter_plan: array; +@group(0) @binding(4) var carry_plan: array; +@group(0) @binding(5) var new_counts: array; +@group(0) @binding(6) var new_offsets: array; +@group(0) @binding(7) var totals: array; +@group(0) @binding(8) var params: vec4; +// params.x = B + +// Workgroup-shared running prefixes for the 3 scans. +var pair_scan: array; +var carry_scan: array; +var new_scan: array; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; + let B = params.x; + + // Phase A: per-thread local read + accumulate. + // Keep PER_THREAD bucket triples in registers (small array). + var local_pc: array; + var local_cf: array; + var local_nc: array; + var sum_p: u32 = 0u; + var sum_c: u32 = 0u; + var sum_n: u32 = 0u; + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tid * PER_THREAD + k; + var pc: u32 = 0u; + var cf: u32 = 0u; + var nc: u32 = 0u; + if (b < B) { + let n = counts[b]; + pc = n / 2u; + cf = n & 1u; + nc = pc + cf; + } + local_pc[k] = pc; + local_cf[k] = cf; + local_nc[k] = nc; + sum_p += pc; + sum_c += cf; + sum_n += nc; + } + + // Phase B: workgroup-wide Hillis-Steele inclusive scan over per- + // thread totals (3 scans interleaved). + pair_scan[tid] = sum_p; + carry_scan[tid] = sum_c; + new_scan[tid] = sum_n; + workgroupBarrier(); + for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) { + var add_p: u32 = 0u; + var add_c: u32 = 0u; + var add_n: u32 = 0u; + if (tid >= stride) { + add_p = pair_scan[tid - stride]; + add_c = carry_scan[tid - stride]; + add_n = new_scan[tid - stride]; + } + workgroupBarrier(); + if (tid >= stride) { + pair_scan[tid] = pair_scan[tid] + add_p; + carry_scan[tid] = carry_scan[tid] + add_c; + new_scan[tid] = new_scan[tid] + add_n; + } + workgroupBarrier(); + } + // pair_scan[tid] is now inclusive prefix. Exclusive base = inclusive - own_sum. + var local_pair_off: u32 = pair_scan[tid] - sum_p; + var local_carry_off: u32 = carry_scan[tid] - sum_c; + var local_new_off: u32 = new_scan[tid] - sum_n; + + // Phase D: thread 0 writes totals (using the FINAL inclusive scan). + if (tid == TPB - 1u) { + totals[0] = pair_scan[tid]; + totals[1] = carry_scan[tid]; + totals[2] = new_scan[tid]; + } + + // Phase C: per-thread scatter. + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tid * PER_THREAD + k; + if (b >= B) { break; } + + let pc = local_pc[k]; + let cf = local_cf[k]; + let nc = local_nc[k]; + new_counts[b] = nc; + new_offsets[b] = local_new_off; + + let bucket_base = offsets[b]; + + // Pair entries: bounded loop, break at pc. + for (var j: u32 = 0u; j < PAIR_CAP; j = j + 1u) { + if (j >= pc) { break; } + let global_slot = local_pair_off + j; + let chunk_id = global_slot / S; + let slot_in_chunk = global_slot % S; + let cp_base = 2u * (chunk_id * S + slot_in_chunk); + chunk_plan[cp_base + 0u] = bucket_base + 2u * j; + chunk_plan[cp_base + 1u] = bucket_base + 2u * j + 1u; + scatter_plan[chunk_id * S + slot_in_chunk] = local_new_off + j; + } + + // Carry entry (if odd count). + if (cf != 0u) { + let cs = local_carry_off; + carry_plan[2u * cs + 0u] = bucket_base + counts[b] - 1u; + carry_plan[2u * cs + 1u] = local_new_off + pc; + } + + local_pair_off += pc; + local_carry_off += cf; + local_new_off += nc; + } + + {{{ recompile }}} +} From 3d7aace70cd084cc0fc37e515d903602d86a30e5 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Wed, 20 May 2026 13:35:06 +0000 Subject: [PATCH 21/33] =?UTF-8?q?feat(bb/msm):=20cuZK=20CSR=20=E2=86=92=20?= =?UTF-8?q?v2=20active=5Fsums=20layout=20converter=20+=20bench?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Step 1 of the v2 pair-tree MSM integration: materialize the bucket-major active_sums / active_counts / active_offsets buffers from the cuZK transpose output (val_idx + row_ptr) and the packed 8×u32 cached bases. Two new shaders: - csr_to_v2_active_sums.template.wgsl — copies packed 8×u32 base coords from new_point_{x,y}[val_idx[slot]] to active_sums_{x,y}[slot] for every (subtask, slot) in the transpose layout. Layout-equivalent raw vec4 copy; no field-element math. Sign handling stays at finalize. - csr_to_v2_meta.template.wgsl — derives active_counts and active_offsets per (subtask, bucket) from cuZK row_ptr. Subtask- relative offsets matching what the v2 pair-tree harness expects. Both wired to shader_manager.ts via gen_csr_to_v2_active_sums_shader and gen_csr_to_v2_meta_shader. Standalone harness: - dev/msm-webgpu/bench-csr-to-v2.{ts,html} — synthetic CSR generator, full byte-equivalent validation against a host reference (active_sums word-for-word, counts/offsets per bucket), plus per-dispatch timing for active_sums, meta, and the two combined into one encoder. - scripts/run-browserstack.mjs — bench-csr-to-v2 entry in pageMap. --- .../ts/dev/msm-webgpu/bench-csr-to-v2.html | 22 + .../ts/dev/msm-webgpu/bench-csr-to-v2.ts | 531 ++++++++++++++++++ .../msm-webgpu/scripts/run-browserstack.mjs | 1 + .../ts/src/msm_webgpu/cuzk/shader_manager.ts | 35 ++ .../src/msm_webgpu/wgsl/_generated/shaders.ts | 104 +++- .../cuzk/csr_to_v2_active_sums.template.wgsl | 53 ++ .../wgsl/cuzk/csr_to_v2_meta.template.wgsl | 45 ++ 7 files changed, 790 insertions(+), 1 deletion(-) create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.html create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.ts create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/csr_to_v2_active_sums.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/csr_to_v2_meta.template.wgsl diff --git a/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.html b/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.html new file mode 100644 index 000000000000..31b300972d08 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.html @@ -0,0 +1,22 @@ + + + + + cuZK CSR -> v2 active_sums layout converter bench (WebGPU) + + + +

cuZK CSR -> v2 active_sums layout converter

+

Query params: ?subtasks=T&columns=C&input=N&wg=W&disp=D&reps=R&validate=1

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.ts b/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.ts new file mode 100644 index 000000000000..9fdb7ab4394e --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.ts @@ -0,0 +1,531 @@ +/// +// Standalone microbench + validator for the cuZK CSR -> v2 active_sums +// layout converter. The converter consumes the cuZK transpose output +// (val_idx + row_ptr) and the packed 8xu32 cached bases, and emits the +// bucket-major active_sums + per-bucket counts/offsets that the v2 +// pair-tree expects. +// +// Inputs are synthetic — random bucket assignments per (subtask, scalar) +// — so the harness validates the converter in isolation without needing +// the full cuZK convert + decompose + transpose pipeline. +// +// Validation: byte-equivalent compare of active_sums_x/y, active_counts, +// and active_offsets against a host-side reference. The reference is +// trivial: active_sums[k] = new_point[val_idx[k]], counts[b] = +// row_ptr[b+1]-row_ptr[b], offsets[b] = row_ptr[b]. +// +// Timing methodology mirrors bench-planner: warmup, then DISP back-to- +// back dispatches in one encoder, divided by DISP for per-call time. + +import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js'; +import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js'; +import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js'; +import { makeResultsClient } from './results_post.js'; + +let NUM_SUBTASKS = 16; +let NUM_COLUMNS = 4096; +let INPUT_SIZE = 1 << 14; +let WG = 64; +let DISP = 32; +let REPS = 5; +let VALIDATE = false; +let SEED = 0xc5af; + +function makeRng(seed: number): () => number { + let state = (seed >>> 0) || 1; + return () => { + state = (Math.imul(state, 1664525) + 1013904223) >>> 0; + return state; + }; +} + +function buildSyntheticCsr( + numSubtasks: number, + numColumns: number, + inputSize: number, + seed: number, +): { rowPtr: Uint32Array; valIdx: Uint32Array } { + const rng = makeRng(seed); + const rowPtr = new Uint32Array(numSubtasks * (numColumns + 1)); + const valIdx = new Uint32Array(numSubtasks * inputSize); + + for (let s = 0; s < numSubtasks; s++) { + const bucketOf = new Uint32Array(inputSize); + const perBucketCount = new Uint32Array(numColumns); + for (let i = 0; i < inputSize; i++) { + const hi = (rng() >>> 16) & 0xffff; + const lo = (rng() >>> 16) & 0xffff; + const r = hi * 0x10000 + lo; + const b = r % numColumns; + bucketOf[i] = b; + perBucketCount[b]++; + } + const rpBase = s * (numColumns + 1); + rowPtr[rpBase] = 0; + for (let b = 0; b < numColumns; b++) { + rowPtr[rpBase + b + 1] = rowPtr[rpBase + b] + perBucketCount[b]; + } + const cursor = new Uint32Array(numColumns); + const viBase = s * inputSize; + for (let i = 0; i < inputSize; i++) { + const b = bucketOf[i]; + const slot = rowPtr[rpBase + b] + cursor[b]; + valIdx[viBase + slot] = i; + cursor[b]++; + } + } + return { rowPtr, valIdx }; +} + +function median(xs: number[]): number { + if (xs.length === 0) return NaN; + const s = xs.slice().sort((a, b) => a - b); + return s[Math.floor(s.length / 2)]; +} + +interface BenchResult { + num_subtasks: number; + num_columns: number; + input_size: number; + total_slots: number; + total_buckets: number; + wg: number; + disp: number; + reps: number; + active_sums_us: { min: number; median: number; max: number }; + meta_us: { min: number; median: number; max: number }; + combined_us: { min: number; median: number; max: number }; + ns_per_slot_combined: number; + validated: boolean; +} + +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: Record | null; + results: BenchResult[]; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { state: 'boot', params: null, results: [], error: null, log: [] }; +(window as unknown as { __bench: BenchState }).__bench = benchState; +const resultsClient = makeResultsClient({ page: 'bench-csr-to-v2' }); +(window as unknown as { __runId: string }).__runId = resultsClient.runId; + +async function postFinal(): Promise { + await resultsClient.postResults({ + state: benchState.state, + params: benchState.params, + results: benchState.results, + error: benchState.error, + log: benchState.log, + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); +} + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-csr-to-v2] ${msg}`); +} + +async function compileOne( + device: GPUDevice, + code: string, + key: string, + layout: GPUBindGroupLayout, +): Promise { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + const errLines: string[] = []; + for (const m of info.messages) { + const line = `[shader ${key}] ${m.type}: ${m.message} (line ${m.lineNum}, col ${m.linePos})`; + if (m.type === 'error') { + console.error(line); + log('err', line); + errLines.push(line); + hasError = true; + } else { + console.warn(line); + } + } + if (hasError) throw new Error(`WGSL compile failed for ${key}: ${errLines.slice(0, 4).join(' | ')}`); + return device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); +} + +async function readbackU32(device: GPUDevice, buf: GPUBuffer, byteLength: number): Promise { + const staging = device.createBuffer({ size: byteLength, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }); + const enc = device.createCommandEncoder(); + enc.copyBufferToBuffer(buf, 0, staging, 0, byteLength); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + await staging.mapAsync(GPUMapMode.READ); + const out = new Uint32Array(staging.getMappedRange().slice(0)); + staging.unmap(); + staging.destroy(); + return out; +} + +async function runOne(device: GPUDevice, sm: ShaderManager): Promise { + const totalSlots = NUM_SUBTASKS * INPUT_SIZE; + const totalBuckets = NUM_SUBTASKS * NUM_COLUMNS; + log( + 'info', + `=== T=${NUM_SUBTASKS} columns=${NUM_COLUMNS} input_size=${INPUT_SIZE} ` + + `total_slots=${totalSlots} total_buckets=${totalBuckets} WG=${WG} DISP=${DISP} REPS=${REPS}`, + ); + + // Synthetic packed bases: 32 bytes per element. Random bytes are fine + // — the converter is a layout-equivalent byte copy. + const rng = makeRng(SEED ^ 0xfeed); + const basesBytes = INPUT_SIZE * 32; + const basesX = new Uint8Array(basesBytes); + const basesY = new Uint8Array(basesBytes); + for (let i = 0; i < basesBytes; i += 4) { + const v = rng(); + basesX[i] = v & 0xff; + basesX[i + 1] = (v >>> 8) & 0xff; + basesX[i + 2] = (v >>> 16) & 0xff; + basesX[i + 3] = (v >>> 24) & 0xff; + const w = rng(); + basesY[i] = w & 0xff; + basesY[i + 1] = (w >>> 8) & 0xff; + basesY[i + 2] = (w >>> 16) & 0xff; + basesY[i + 3] = (w >>> 24) & 0xff; + } + log('info', `synthetic packed bases: ${(basesBytes * 2 / 1024).toFixed(1)} KiB total`); + + const { rowPtr, valIdx } = buildSyntheticCsr(NUM_SUBTASKS, NUM_COLUMNS, INPUT_SIZE, SEED); + log('info', `synthetic CSR: rowPtr=${(rowPtr.byteLength / 1024).toFixed(1)} KiB valIdx=${(valIdx.byteLength / 1024).toFixed(1)} KiB`); + + const mk = (bytes: number, copySrc = false, copyDst = false): GPUBuffer => { + let usage = GPUBufferUsage.STORAGE; + if (copySrc) usage |= GPUBufferUsage.COPY_SRC; + if (copyDst) usage |= GPUBufferUsage.COPY_DST; + return device.createBuffer({ size: bytes, usage }); + }; + + const basesXBuf = mk(basesBytes, false, true); + const basesYBuf = mk(basesBytes, false, true); + const valIdxBuf = mk(valIdx.byteLength, false, true); + const rowPtrBuf = mk(rowPtr.byteLength, false, true); + const activeXBuf = mk(totalSlots * 32, true); + const activeYBuf = mk(totalSlots * 32, true); + const countsBuf = mk(totalBuckets * 4, true); + const offsetsBuf = mk(totalBuckets * 4, true); + device.queue.writeBuffer(basesXBuf, 0, basesX); + device.queue.writeBuffer(basesYBuf, 0, basesY); + device.queue.writeBuffer(valIdxBuf, 0, valIdx); + device.queue.writeBuffer(rowPtrBuf, 0, rowPtr); + + const paramsActiveBuf = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(paramsActiveBuf, 0, new Uint32Array([totalSlots, 0, 0, 0])); + const paramsMetaBuf = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(paramsMetaBuf, 0, new Uint32Array([NUM_COLUMNS, totalBuckets, 0, 0])); + + const activeLayout = device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 5, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); + const metaLayout = device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); + const activePipeline = await compileOne(device, sm.gen_csr_to_v2_active_sums_shader(WG), `csr-to-v2-active_sums-wg${WG}`, activeLayout); + const metaPipeline = await compileOne(device, sm.gen_csr_to_v2_meta_shader(WG), `csr-to-v2-meta-wg${WG}`, metaLayout); + + const activeBind = device.createBindGroup({ + layout: activeLayout, + entries: [ + { binding: 0, resource: { buffer: valIdxBuf } }, + { binding: 1, resource: { buffer: basesXBuf } }, + { binding: 2, resource: { buffer: basesYBuf } }, + { binding: 3, resource: { buffer: activeXBuf } }, + { binding: 4, resource: { buffer: activeYBuf } }, + { binding: 5, resource: { buffer: paramsActiveBuf } }, + ], + }); + const metaBind = device.createBindGroup({ + layout: metaLayout, + entries: [ + { binding: 0, resource: { buffer: rowPtrBuf } }, + { binding: 1, resource: { buffer: countsBuf } }, + { binding: 2, resource: { buffer: offsetsBuf } }, + { binding: 3, resource: { buffer: paramsMetaBuf } }, + ], + }); + + const groupsActive = Math.ceil(totalSlots / WG); + const groupsMeta = Math.ceil(totalBuckets / WG); + log('info', `dispatch shape: active=${groupsActive} wgs, meta=${groupsMeta} wgs`); + + // Warmup + { + const enc = device.createCommandEncoder(); + let pass = enc.beginComputePass(); + pass.setPipeline(activePipeline); + pass.setBindGroup(0, activeBind); + pass.dispatchWorkgroups(groupsActive, 1, 1); + pass.end(); + pass = enc.beginComputePass(); + pass.setPipeline(metaPipeline); + pass.setBindGroup(0, metaBind); + pass.dispatchWorkgroups(groupsMeta, 1, 1); + pass.end(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + } + log('info', 'warmup done'); + + let validated = false; + if (VALIDATE) { + const gpuActiveX = await readbackU32(device, activeXBuf, totalSlots * 32); + const gpuActiveY = await readbackU32(device, activeYBuf, totalSlots * 32); + const gpuCounts = await readbackU32(device, countsBuf, totalBuckets * 4); + const gpuOffsets = await readbackU32(device, offsetsBuf, totalBuckets * 4); + const refBasesX = new Uint32Array(basesX.buffer, basesX.byteOffset, INPUT_SIZE * 8); + const refBasesY = new Uint32Array(basesY.buffer, basesY.byteOffset, INPUT_SIZE * 8); + + const mismatches: string[] = []; + let xFails = 0; + let yFails = 0; + for (let s = 0; s < NUM_SUBTASKS && mismatches.length < 16; s++) { + const viBase = s * INPUT_SIZE; + const probeStart = Math.max(0, INPUT_SIZE - 4); + for (const k of [0, 1, 7, INPUT_SIZE >> 1, probeStart, INPUT_SIZE - 1]) { + const slot = viBase + k; + const ptIdx = valIdx[slot]; + for (let w = 0; w < 8; w++) { + const got = gpuActiveX[slot * 8 + w]; + const want = refBasesX[ptIdx * 8 + w]; + if (got !== want) { + xFails++; + if (mismatches.length < 8) mismatches.push(`activeX[s=${s} k=${k} w=${w}]: gpu=${got} ref=${want}`); + } + const gotY = gpuActiveY[slot * 8 + w]; + const wantY = refBasesY[ptIdx * 8 + w]; + if (gotY !== wantY) { + yFails++; + if (mismatches.length < 8) mismatches.push(`activeY[s=${s} k=${k} w=${w}]: gpu=${gotY} ref=${wantY}`); + } + } + } + } + // Full-pass byte compare (cheap; the buffers are u32 arrays). + for (let k = 0; k < totalSlots * 8; k++) { + const ptIdx = valIdx[k >> 3]; + const w = k & 7; + if (gpuActiveX[k] !== refBasesX[ptIdx * 8 + w]) xFails++; + if (gpuActiveY[k] !== refBasesY[ptIdx * 8 + w]) yFails++; + } + let mFails = 0; + for (let s = 0; s < NUM_SUBTASKS; s++) { + const rpBase = s * (NUM_COLUMNS + 1); + const ccBase = s * NUM_COLUMNS; + for (let b = 0; b < NUM_COLUMNS; b++) { + const wantCount = rowPtr[rpBase + b + 1] - rowPtr[rpBase + b]; + const wantOffset = rowPtr[rpBase + b]; + if (gpuCounts[ccBase + b] !== wantCount) { + mFails++; + if (mismatches.length < 12) mismatches.push(`counts[s=${s} b=${b}]: gpu=${gpuCounts[ccBase + b]} ref=${wantCount}`); + } + if (gpuOffsets[ccBase + b] !== wantOffset) { + mFails++; + if (mismatches.length < 12) mismatches.push(`offsets[s=${s} b=${b}]: gpu=${gpuOffsets[ccBase + b]} ref=${wantOffset}`); + } + } + } + if (xFails === 0 && yFails === 0 && mFails === 0) { + validated = true; + log('ok', `validation: PASS — converter output byte-equivalent to host reference (${totalSlots} slots, ${totalBuckets} buckets)`); + } else { + log('err', `validation: FAIL — activeX mismatches=${xFails}, activeY=${yFails}, meta=${mFails}`); + for (const m of mismatches.slice(0, 12)) log('err', ` ${m}`); + } + } + + // Timed: DISP back-to-back encoded as one submit per rep, split active vs meta. + const activeSamples: number[] = []; + const metaSamples: number[] = []; + const combinedSamples: number[] = []; + for (let r = 0; r < REPS; r++) { + { + const enc = device.createCommandEncoder(); + for (let d = 0; d < DISP; d++) { + const pass = enc.beginComputePass(); + pass.setPipeline(activePipeline); + pass.setBindGroup(0, activeBind); + pass.dispatchWorkgroups(groupsActive, 1, 1); + pass.end(); + } + const t0 = performance.now(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + activeSamples.push(performance.now() - t0); + } + { + const enc = device.createCommandEncoder(); + for (let d = 0; d < DISP; d++) { + const pass = enc.beginComputePass(); + pass.setPipeline(metaPipeline); + pass.setBindGroup(0, metaBind); + pass.dispatchWorkgroups(groupsMeta, 1, 1); + pass.end(); + } + const t0 = performance.now(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + metaSamples.push(performance.now() - t0); + } + { + const enc = device.createCommandEncoder(); + for (let d = 0; d < DISP; d++) { + let pass = enc.beginComputePass(); + pass.setPipeline(activePipeline); + pass.setBindGroup(0, activeBind); + pass.dispatchWorkgroups(groupsActive, 1, 1); + pass.end(); + pass = enc.beginComputePass(); + pass.setPipeline(metaPipeline); + pass.setBindGroup(0, metaBind); + pass.dispatchWorkgroups(groupsMeta, 1, 1); + pass.end(); + } + const t0 = performance.now(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + combinedSamples.push(performance.now() - t0); + } + } + const stat = (xs: number[]): { min: number; median: number; max: number } => { + const med = median(xs); + return { min: (Math.min(...xs) / DISP) * 1000, median: (med / DISP) * 1000, max: (Math.max(...xs) / DISP) * 1000 }; + }; + const activeStats = stat(activeSamples); + const metaStats = stat(metaSamples); + const combinedStats = stat(combinedSamples); + const nsPerSlotCombined = (combinedStats.median * 1000) / totalSlots; + log( + 'ok', + `active_sums per-dispatch: median=${activeStats.median.toFixed(2)}μs min=${activeStats.min.toFixed(2)}μs max=${activeStats.max.toFixed(2)}μs`, + ); + log( + 'ok', + `meta per-dispatch: median=${metaStats.median.toFixed(2)}μs min=${metaStats.min.toFixed(2)}μs max=${metaStats.max.toFixed(2)}μs`, + ); + log( + 'ok', + `combined per-dispatch: median=${combinedStats.median.toFixed(2)}μs min=${combinedStats.min.toFixed(2)}μs max=${combinedStats.max.toFixed(2)}μs ` + + `→ ${nsPerSlotCombined.toFixed(3)} ns/slot (${totalSlots} slots)`, + ); + + basesXBuf.destroy(); + basesYBuf.destroy(); + valIdxBuf.destroy(); + rowPtrBuf.destroy(); + activeXBuf.destroy(); + activeYBuf.destroy(); + countsBuf.destroy(); + offsetsBuf.destroy(); + paramsActiveBuf.destroy(); + paramsMetaBuf.destroy(); + + return { + num_subtasks: NUM_SUBTASKS, + num_columns: NUM_COLUMNS, + input_size: INPUT_SIZE, + total_slots: totalSlots, + total_buckets: totalBuckets, + wg: WG, + disp: DISP, + reps: REPS, + active_sums_us: activeStats, + meta_us: metaStats, + combined_us: combinedStats, + ns_per_slot_combined: nsPerSlotCombined, + validated, + }; +} + +function parseParams() { + const qp = new URLSearchParams(window.location.search); + if (qp.get('subtasks')) NUM_SUBTASKS = parseInt(qp.get('subtasks')!, 10); + if (qp.get('columns')) NUM_COLUMNS = parseInt(qp.get('columns')!, 10); + if (qp.get('input')) INPUT_SIZE = parseInt(qp.get('input')!, 10); + if (qp.get('wg')) WG = parseInt(qp.get('wg')!, 10); + if (qp.get('disp')) DISP = parseInt(qp.get('disp')!, 10); + if (qp.get('reps')) REPS = parseInt(qp.get('reps')!, 10); + if (qp.get('seed')) SEED = parseInt(qp.get('seed')!, 10); + if (qp.get('validate') === '1') VALIDATE = true; + return { + subtasks: NUM_SUBTASKS, + columns: NUM_COLUMNS, + input_size: INPUT_SIZE, + wg: WG, + disp: DISP, + reps: REPS, + seed: SEED, + validate: VALIDATE, + }; +} + +async function main() { + try { + if (!('gpu' in navigator)) throw new Error('navigator.gpu missing'); + const params = parseParams(); + benchState.params = params; + log('info', `params: ${JSON.stringify(params)}`); + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + const sm = new ShaderManager(NUM_SUBTASKS, NUM_COLUMNS, BN254_CURVE_CONFIG, false); + const r = await runOne(device, sm); + benchState.results.push(r); + resultsClient.postProgress({ + kind: 'csr_to_v2_done', + active_sums_us: r.active_sums_us, + meta_us: r.meta_us, + combined_us: r.combined_us, + ns_per_slot_combined: r.ns_per_slot_combined, + validated: r.validated, + }); + benchState.state = 'done'; + log('ok', 'done'); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main() + .catch(e => { + const msg = e instanceof Error ? e.message : String(e); + log('err', `unhandled: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + }) + .finally(() => { + postFinal().catch(() => {}); + }); diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs index 44247b10b8ab..f89804e80ce6 100644 --- a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs +++ b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs @@ -135,6 +135,7 @@ const pageMap = { "bench-msm-tree-v2": "/dev/msm-webgpu/bench-msm-tree-v2.html", "bench-msm-tree-v3": "/dev/msm-webgpu/bench-msm-tree-v3.html", "bench-planner": "/dev/msm-webgpu/bench-planner.html", + "bench-csr-to-v2": "/dev/msm-webgpu/bench-csr-to-v2.html", "bench-smvp-tree": "/dev/msm-webgpu/bench-smvp-tree.html", sanity: "/dev/msm-webgpu/index.html", }; diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts index 24fe0121c2f4..ab28aa1b4425 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts @@ -40,6 +40,8 @@ import { bpr_bn254 as bpr_bn254_shader, convert_point_coords_and_decompose_scalars, convert_points_only as convert_points_only_shader, + csr_to_v2_active_sums as csr_to_v2_active_sums_shader, + csr_to_v2_meta as csr_to_v2_meta_shader, decompose_scalars_signed_only as decompose_scalars_signed_only_shader, decompress_g1_bn254 as decompress_g1_bn254_shader, divsteps_bench as divsteps_bench_shader, @@ -894,6 +896,39 @@ ${packLines.join('\n')} ); } + /** + * Layout converter (active_sums materialization): copies packed + * 8×u32 base coords from cached_bases into bucket-major active_sums + * indexed by val_idx. One thread per (subtask, slot). Pure raw vec4 + * copy — no field-element math. + */ + public gen_csr_to_v2_active_sums_shader(workgroup_size: number): string { + if (workgroup_size <= 0 || !Number.isInteger(workgroup_size)) { + throw new Error(`gen_csr_to_v2_active_sums_shader: workgroup_size (${workgroup_size}) must be a positive integer`); + } + return mustache.render( + csr_to_v2_active_sums_shader, + { workgroup_size, recompile: this.recompile }, + {}, + ); + } + + /** + * Layout converter (meta derivation): writes per-bucket count and + * subtask-relative offset from cuZK row_ptr. One thread per + * (subtask, bucket). + */ + public gen_csr_to_v2_meta_shader(workgroup_size: number): string { + if (workgroup_size <= 0 || !Number.isInteger(workgroup_size)) { + throw new Error(`gen_csr_to_v2_meta_shader: workgroup_size (${workgroup_size}) must be a positive integer`); + } + return mustache.render( + csr_to_v2_meta_shader, + { workgroup_size, recompile: this.recompile }, + {}, + ); + } + /** * Bin-packed pair-tree: carry-copy kernel. Propagates the odd-count * carry element forward to the next level without modification. diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index 232ae7a49364..60ee225eba91 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -1,6 +1,6 @@ // AUTO-GENERATED by scripts/inline-wgsl.mjs. DO NOT EDIT. // Run `yarn generate:wgsl` (or `node scripts/inline-wgsl.mjs`) to regenerate. -// 62 shader sources inlined. +// 64 shader sources inlined. /* eslint-disable */ @@ -6281,6 +6281,108 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { } `; +export const csr_to_v2_active_sums = `// Layout converter for the v2 pair-tree MSM bucket-accumulate path. +// +// Materializes the bucket-major active_sums buffer by copying packed +// 8×u32 base coords from the cached_bases (new_point_x / new_point_y) +// at the indices listed in val_idx (cuZK transpose output, bucket-major +// per subtask). +// +// Per (subtask s, slot k) thread, with slot = s * input_size + k: +// pt_idx = val_idx[slot] +// active_sums_x[slot] = new_point_x[pt_idx] +// active_sums_y[slot] = new_point_y[pt_idx] +// +// Both source and destination are packed 8×u32 (two vec4 per field +// element). The copy is a raw element copy — destination element bytes +// equal source element bytes; no unpack / pack needed. +// +// Sign handling: cuZK encodes signed slices via bucket index, not via +// point negation, so the converter does not flip y. The finalize pass +// negates y for negative-bucket contributions. + +@group(0) @binding(0) +var val_idx: array; +@group(0) @binding(1) +var new_point_x: array>; +@group(0) @binding(2) +var new_point_y: array>; +@group(0) @binding(3) +var active_sums_x: array>; +@group(0) @binding(4) +var active_sums_y: array>; + +// params[0] = total_slots (num_subtasks * input_size) +@group(0) @binding(5) +var params: vec4; + +@compute +@workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let slot = gid.x; + let total = params[0]; + if (slot >= total) { + return; + } + + let pt_idx = val_idx[slot]; + + active_sums_x[2u * slot] = new_point_x[2u * pt_idx]; + active_sums_x[2u * slot + 1u] = new_point_x[2u * pt_idx + 1u]; + active_sums_y[2u * slot] = new_point_y[2u * pt_idx]; + active_sums_y[2u * slot + 1u] = new_point_y[2u * pt_idx + 1u]; + + {{{ recompile }}} +} +`; + +export const csr_to_v2_meta = `// Companion to csr_to_v2_active_sums: derives the per-bucket counts and +// subtask-relative offsets that drive the v2 pair-tree planner. +// +// row_ptr layout: per subtask, num_columns + 1 entries forming a +// CSR-style prefix sum. row_ptr[s * (num_columns + 1) + b + 1] - +// row_ptr[s * (num_columns + 1) + b] is the count of points in bucket +// b of subtask s, and the begin value is the subtask-relative start +// offset within val_idx and active_sums. +// +// One thread per (subtask, bucket) emits one (count, offset) pair. + +@group(0) @binding(0) +var row_ptr: array; +@group(0) @binding(1) +var active_counts: array; +@group(0) @binding(2) +var active_offsets: array; + +// params[0] = num_columns +// params[1] = total_buckets (num_subtasks * num_columns) +@group(0) @binding(3) +var params: vec4; + +@compute +@workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let id = gid.x; + let total = params[1]; + if (id >= total) { + return; + } + + let num_columns = params[0]; + let subtask = id / num_columns; + let bucket_local = id % num_columns; + let rp_offset = subtask * (num_columns + 1u); + + let begin = row_ptr[rp_offset + bucket_local]; + let end = row_ptr[rp_offset + bucket_local + 1u]; + + active_counts[id] = end - begin; + active_offsets[id] = begin; + + {{{ recompile }}} +} +`; + export const decompose_scalars_signed_only = `// Scalars-only variant of \`convert_point_coords_and_decompose_scalars\`. // Reads 32-byte LE scalars from a packed u32 buffer and writes one // shifted-signed bucket index per scalar per subtask into \`chunks\`. diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/csr_to_v2_active_sums.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/csr_to_v2_active_sums.template.wgsl new file mode 100644 index 000000000000..e52219756a56 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/csr_to_v2_active_sums.template.wgsl @@ -0,0 +1,53 @@ +// Layout converter for the v2 pair-tree MSM bucket-accumulate path. +// +// Materializes the bucket-major active_sums buffer by copying packed +// 8×u32 base coords from the cached_bases (new_point_x / new_point_y) +// at the indices listed in val_idx (cuZK transpose output, bucket-major +// per subtask). +// +// Per (subtask s, slot k) thread, with slot = s * input_size + k: +// pt_idx = val_idx[slot] +// active_sums_x[slot] = new_point_x[pt_idx] +// active_sums_y[slot] = new_point_y[pt_idx] +// +// Both source and destination are packed 8×u32 (two vec4 per field +// element). The copy is a raw element copy — destination element bytes +// equal source element bytes; no unpack / pack needed. +// +// Sign handling: cuZK encodes signed slices via bucket index, not via +// point negation, so the converter does not flip y. The finalize pass +// negates y for negative-bucket contributions. + +@group(0) @binding(0) +var val_idx: array; +@group(0) @binding(1) +var new_point_x: array>; +@group(0) @binding(2) +var new_point_y: array>; +@group(0) @binding(3) +var active_sums_x: array>; +@group(0) @binding(4) +var active_sums_y: array>; + +// params[0] = total_slots (num_subtasks * input_size) +@group(0) @binding(5) +var params: vec4; + +@compute +@workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let slot = gid.x; + let total = params[0]; + if (slot >= total) { + return; + } + + let pt_idx = val_idx[slot]; + + active_sums_x[2u * slot] = new_point_x[2u * pt_idx]; + active_sums_x[2u * slot + 1u] = new_point_x[2u * pt_idx + 1u]; + active_sums_y[2u * slot] = new_point_y[2u * pt_idx]; + active_sums_y[2u * slot + 1u] = new_point_y[2u * pt_idx + 1u]; + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/csr_to_v2_meta.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/csr_to_v2_meta.template.wgsl new file mode 100644 index 000000000000..fdc595a0b3bc --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/csr_to_v2_meta.template.wgsl @@ -0,0 +1,45 @@ +// Companion to csr_to_v2_active_sums: derives the per-bucket counts and +// subtask-relative offsets that drive the v2 pair-tree planner. +// +// row_ptr layout: per subtask, num_columns + 1 entries forming a +// CSR-style prefix sum. row_ptr[s * (num_columns + 1) + b + 1] - +// row_ptr[s * (num_columns + 1) + b] is the count of points in bucket +// b of subtask s, and the begin value is the subtask-relative start +// offset within val_idx and active_sums. +// +// One thread per (subtask, bucket) emits one (count, offset) pair. + +@group(0) @binding(0) +var row_ptr: array; +@group(0) @binding(1) +var active_counts: array; +@group(0) @binding(2) +var active_offsets: array; + +// params[0] = num_columns +// params[1] = total_buckets (num_subtasks * num_columns) +@group(0) @binding(3) +var params: vec4; + +@compute +@workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let id = gid.x; + let total = params[1]; + if (id >= total) { + return; + } + + let num_columns = params[0]; + let subtask = id / num_columns; + let bucket_local = id % num_columns; + let rp_offset = subtask * (num_columns + 1u); + + let begin = row_ptr[rp_offset + bucket_local]; + let end = row_ptr[rp_offset + bucket_local + 1u]; + + active_counts[id] = end - begin; + active_offsets[id] = begin; + + {{{ recompile }}} +} From a7caeb3dea9b379446602dc945dfbccc7c49b2ae Mon Sep 17 00:00:00 2001 From: AztecBot Date: Wed, 20 May 2026 13:58:44 +0000 Subject: [PATCH 22/33] feat(bb/msm): noble-CPU end-to-end oracle for v2 pair-tree MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Step 2 of the v2 pair-tree MSM integration: a tiny standalone oracle that runs the bin-packed pair-tree bucket-accumulate on REAL BN254 points and verifies each per-bucket reduced sum against a noble projective reference. The test fused_revcarry never had. Validates the v2 pair-tree's round- kernel math end-to-end on real curve data — disjoint pair-sum, suffix- product single-fr_inv per chunk, lean affine add formula, and the multi-level bin-packed reduction. Scope: bucket-accumulate phase only. BPR / horner / finalize are deferred to later steps of the rewrite. Sizing: N=256 points, B=32 buckets, single window — small enough for noble's projective add per bucket to run instantly in the browser. What's in the harness: - dev/msm-webgpu/bench-msm-oracle.{ts,html} — generates N random BN254 affine points via @noble/curves (k*G with random k ∈ [1, order)), converts (x,y) canonical → Montgomery (R = 2^260 mod p, matching compute_misc_params(p, 13)), lays out bucket-major SoA + pad pair exactly as the v2 pair-tree expects, runs the existing marshal / disjoint / scatter / carry-copy pipeline single-submit, reads back, converts each non-empty bucket's reduced sum back Mont → canonical via R^-1, and cross-checks (x, y) byte-equal to noble's affine sum of the original points assigned to that bucket. - scripts/run-browserstack.mjs — bench-msm-oracle entry in pageMap so it's reachable from the BS M2 runner. Query params: ?n=256&buckets=32&s=16&wgi=64&seed=K (defaults shown). This commit only ships the harness; the BS M2 oracle run is the validation gate I will trigger next. --- .../ts/dev/msm-webgpu/bench-msm-oracle.html | 22 + .../ts/dev/msm-webgpu/bench-msm-oracle.ts | 622 ++++++++++++++++++ .../msm-webgpu/scripts/run-browserstack.mjs | 1 + 3 files changed, 645 insertions(+) create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.html create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.ts diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.html b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.html new file mode 100644 index 000000000000..8296a4e0844b --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.html @@ -0,0 +1,22 @@ + + + + + v2 pair-tree bucket-accumulate noble-CPU oracle (WebGPU) + + + +

v2 pair-tree bucket-accumulate noble-CPU oracle

+

Query params: ?n=N&buckets=B&s=S&wgi=W&seed=K

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.ts new file mode 100644 index 000000000000..a2ab8585ec09 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.ts @@ -0,0 +1,622 @@ +/// +// End-to-end correctness oracle for the v2 bin-packed pair-tree +// bucket-accumulate pipeline. Feeds REAL BN254 affine points (random +// scalar * G, via @noble/curves) into the pair-tree and verifies that +// the per-bucket reduced sum matches a noble-projective reference. +// +// This is the test fused_revcarry never had: a ground-truth oracle on +// real curve data. If this passes, the v2 pair-tree's bucket-accumulate +// math (disjoint pair-sum + suffix-product single-fr_inv + lean affine +// add) is correct end-to-end on real BN254 points. +// +// Scope: validates the round kernel + planner only. BPR / horner / +// finalize are NOT part of v2 yet — they're step 3+ of the rewrite plan. +// This oracle stops at "per-bucket sum is correct". +// +// Sizing: tiny by design — N=256 points, B=32 buckets, single window +// (no signed slicing). Each bucket gets ~8 points; the pair-tree +// reduces in ~4 levels. + +import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js'; +import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js'; +import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js'; +import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js'; +import { BN254_BASE_FIELD, modInverse } from '../../src/msm_webgpu/cuzk/bn254.js'; +import { makeResultsClient } from './results_post.js'; +import { bn254 } from '@noble/curves/bn254'; + +const PG = 2; +let NPTS = 256; +let BUCKETS = 32; +let S = 16; +let WGI = 64; +let SEED = 0xa110ce; + +function makeRng(seed: number): () => number { + let state = (seed >>> 0) || 1; + return () => { + state = (Math.imul(state, 1664525) + 1013904223) >>> 0; + return state; + }; +} + +function bigintToPackedU32x8(v: bigint): Uint32Array { + const w = new Uint32Array(8); + let x = v; + for (let i = 0; i < 8; i++) { + w[i] = Number(x & 0xffffffffn); + x >>= 32n; + } + return w; +} + +function packedU32x8ToBigint(w: Uint32Array, off: number): bigint { + let v = 0n; + for (let i = 7; i >= 0; i--) v = (v << 32n) | BigInt(w[off + i] >>> 0); + return v; +} + +function makeSoABuf(device: GPUDevice, M: number, copyDst: boolean, copySrc: boolean): GPUBuffer { + const bytes = 2 * PG * M * 4 * 4; + let usage = GPUBufferUsage.STORAGE; + if (copyDst) usage |= GPUBufferUsage.COPY_DST; + if (copySrc) usage |= GPUBufferUsage.COPY_SRC; + return device.createBuffer({ size: bytes, usage }); +} + +interface CurvePoint { + x: bigint; + y: bigint; +} + +function buildL0WithRealPoints( + N: number, + B: number, + R: bigint, + p: bigint, + rng: () => number, +): { + initBuf: Uint32Array; + initCounts: Uint32Array; + initOffsets: Uint32Array; + M: number; + points: CurvePoint[]; + bucket: Uint32Array; +} { + const M = N + 2; + const buf = new Uint32Array(2 * PG * M * 4); + const G1 = bn254.G1.ProjectivePoint; + const order = bn254.fields.Fr.ORDER; + + const points: CurvePoint[] = []; + const xWords = new Uint32Array(8 * M); + const yWords = new Uint32Array(8 * M); + + for (let i = 0; i < N; i++) { + let k = 0n; + for (let w = 0; w < 8; w++) k = (k << 32n) | BigInt(rng() >>> 0); + k = k % order; + if (k === 0n) k = 1n; + const aff = G1.BASE.multiply(k).toAffine(); + points.push({ x: aff.x, y: aff.y }); + const xMont = (aff.x * R) % p; + const yMont = (aff.y * R) % p; + xWords.set(bigintToPackedU32x8(xMont), 8 * i); + yWords.set(bigintToPackedU32x8(yMont), 8 * i); + } + + for (let pad = 0; pad < 2; pad++) { + const i = N + pad; + let xCand: bigint; + do { + xCand = 0n; + for (let w = 0; w < 8; w++) xCand = (xCand << 32n) | BigInt(rng() >>> 0); + xCand = xCand % p; + } while (xCand === 0n); + const yCand = ((xCand + 1n + BigInt(pad)) * 7n) % p; + const xMont = (xCand * R) % p; + const yMont = (yCand * R) % p; + xWords.set(bigintToPackedU32x8(xMont), 8 * i); + yWords.set(bigintToPackedU32x8(yMont), 8 * i); + } + + const bucket = new Uint32Array(N); + const counts = new Uint32Array(B); + for (let i = 0; i < N; i++) { + const hi = (rng() >>> 16) & 0xffff; + const lo = (rng() >>> 16) & 0xffff; + const v = hi * 0x10000 + lo; + const b = v % B; + bucket[i] = b; + counts[b]++; + } + const offsets = new Uint32Array(B + 1); + for (let b = 0; b < B; b++) offsets[b + 1] = offsets[b] + counts[b]; + + const cursor = new Uint32Array(B); + const writeElem = (planeIdx: number, dstIdx: number, words: Uint32Array, srcOff: number) => { + for (let v = 0; v < PG; v++) { + const base = ((planeIdx * PG + v) * M + dstIdx) * 4; + buf[base + 0] = words[srcOff + 4 * v + 0]; + buf[base + 1] = words[srcOff + 4 * v + 1]; + buf[base + 2] = words[srcOff + 4 * v + 2]; + buf[base + 3] = words[srcOff + 4 * v + 3]; + } + }; + for (let i = 0; i < N; i++) { + const b = bucket[i]; + const dst = offsets[b] + cursor[b]++; + writeElem(0, dst, xWords, 8 * i); + writeElem(1, dst, yWords, 8 * i); + } + writeElem(0, M - 2, xWords, 8 * (M - 2)); + writeElem(1, M - 2, yWords, 8 * (M - 2)); + writeElem(0, M - 1, xWords, 8 * (M - 1)); + writeElem(1, M - 1, yWords, 8 * (M - 1)); + + return { initBuf: buf, initCounts: counts, initOffsets: offsets, M, points, bucket }; +} + +function buildLevelPlan( + counts: Uint32Array, + offsets: Uint32Array, + s: number, + padLIdx: number, + padRIdx: number, + discardIdx: number, +) { + const B = counts.length; + let totalPairs = 0; + let totalCarries = 0; + const newCounts = new Uint32Array(B); + for (let b = 0; b < B; b++) { + const n = counts[b]; + const p = Math.floor(n / 2); + const c = n & 1; + totalPairs += p; + totalCarries += c; + newCounts[b] = p + c; + } + const newOffsets = new Uint32Array(B + 1); + for (let b = 0; b < B; b++) newOffsets[b + 1] = newOffsets[b] + newCounts[b]; + + const numChunks = Math.max(1, Math.ceil(totalPairs / s)); + const chunkPlan = new Uint32Array(2 * s * numChunks); + const scatterPlan = new Uint32Array(s * numChunks); + const carryPlan = new Uint32Array(2 * Math.max(1, totalCarries)); + + for (let i = 0; i < numChunks * s; i++) { + chunkPlan[2 * i + 0] = padLIdx; + chunkPlan[2 * i + 1] = padRIdx; + scatterPlan[i] = discardIdx; + } + + let slot = 0; + let carryIdx = 0; + for (let b = 0; b < B; b++) { + const n = counts[b]; + const p = Math.floor(n / 2); + for (let j = 0; j < p; j++) { + chunkPlan[2 * slot + 0] = offsets[b] + 2 * j; + chunkPlan[2 * slot + 1] = offsets[b] + 2 * j + 1; + scatterPlan[slot] = newOffsets[b] + j; + slot++; + } + if (n & 1) { + carryPlan[2 * carryIdx + 0] = offsets[b] + n - 1; + carryPlan[2 * carryIdx + 1] = newOffsets[b] + p; + carryIdx++; + } + } + return { chunkPlan, scatterPlan, carryPlan, newCounts, newOffsets, numChunks, numCarries: totalCarries, totalPairs }; +} + +interface BucketCheck { + bucket: number; + count: number; + gpu_x: string; + gpu_y: string; + ref_x: string; + ref_y: string; + ok: boolean; +} + +interface OracleResult { + n: number; + buckets: number; + s: number; + wgi: number; + levels: number; + total_pair_adds: number; + buckets_checked: number; + buckets_passed: number; + first_mismatches: BucketCheck[]; + all_passed: boolean; + gpu_wall_ms: number; +} + +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: Record | null; + results: OracleResult[]; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { state: 'boot', params: null, results: [], error: null, log: [] }; +(window as unknown as { __bench: BenchState }).__bench = benchState; +const resultsClient = makeResultsClient({ page: 'bench-msm-oracle' }); +(window as unknown as { __runId: string }).__runId = resultsClient.runId; + +async function postFinal(): Promise { + await resultsClient.postResults({ + state: benchState.state, + params: benchState.params, + results: benchState.results, + error: benchState.error, + log: benchState.log, + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); +} + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-msm-oracle] ${msg}`); +} + +async function compileOne(device: GPUDevice, code: string, key: string, layout: GPUBindGroupLayout): Promise { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + const errLines: string[] = []; + for (const m of info.messages) { + const line = `[shader ${key}] ${m.type}: ${m.message} (line ${m.lineNum}, col ${m.linePos})`; + if (m.type === 'error') { + console.error(line); + log('err', line); + errLines.push(line); + hasError = true; + } else { + console.warn(line); + } + } + if (hasError) throw new Error(`WGSL compile failed for ${key}: ${errLines.slice(0, 4).join(' | ')}`); + return device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); +} + +function ioLayout4(device: GPUDevice): GPUBindGroupLayout { + return device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); +} + +async function readbackU32(device: GPUDevice, buf: GPUBuffer, bytes: number): Promise { + const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }); + const enc = device.createCommandEncoder(); + enc.copyBufferToBuffer(buf, 0, staging, 0, bytes); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + await staging.mapAsync(GPUMapMode.READ); + const out = new Uint32Array(staging.getMappedRange().slice(0)); + staging.unmap(); + staging.destroy(); + return out; +} + +async function runOracle(device: GPUDevice, sm: ShaderManager, R: bigint, Rinv: bigint, p: bigint): Promise { + log('info', `=== N=${NPTS} B=${BUCKETS} S=${S} WGI=${WGI}`); + const rng = makeRng(SEED); + const { initBuf, initCounts, initOffsets, M, points, bucket } = buildL0WithRealPoints(NPTS, BUCKETS, R, p, rng); + + let cMin = NPTS, cMax = 0, cZero = 0; + for (let b = 0; b < BUCKETS; b++) { + if (initCounts[b] > cMax) cMax = initCounts[b]; + if (initCounts[b] < cMin) cMin = initCounts[b]; + if (initCounts[b] === 0) cZero++; + } + log('info', `built L0: M=${M} bucket counts min=${cMin} max=${cMax} zero=${cZero}/${BUCKETS}`); + + const padLIdx = M - 2; + const padRIdx = M - 1; + const discardIdx = M - 2; + + const bufA = makeSoABuf(device, M, true, true); + const bufB = makeSoABuf(device, M, true, true); + device.queue.writeBuffer(bufA, 0, initBuf); + device.queue.writeBuffer(bufB, 0, initBuf); + + const maxL0Chunks = Math.ceil(NPTS / 2 / S) + 1; + const chainBuf = makeSoABuf(device, 2 * S * maxL0Chunks, false, false); + const tempOutBuf = makeSoABuf(device, S * maxL0Chunks, false, true); + + const layoutMarshal = ioLayout4(device); + const layoutDisjoint = ioLayout4(device); + const layoutScatter = ioLayout4(device); + const layoutCarry = ioLayout4(device); + const marshalPipe = await compileOne(device, sm.gen_ba_marshal_pairs_bench_shader(WGI, S), `marshal-W${WGI}-S${S}`, layoutMarshal); + const disjointPipe = await compileOne(device, sm.gen_ba_pair_disjoint_tree_bench_shader(WGI, S), `disjoint-W${WGI}-S${S}`, layoutDisjoint); + const scatterPipe = await compileOne(device, sm.gen_ba_scatter_pairs_bench_shader(WGI, S), `scatter-W${WGI}-S${S}`, layoutScatter); + const carryPipe = await compileOne(device, sm.gen_ba_carry_copy_bench_shader(WGI), `carry-W${WGI}`, layoutCarry); + log('info', '4 pipelines compiled'); + + let counts = initCounts; + let offsets = initOffsets; + let finalOffsets: Uint32Array = initOffsets; + let curIn: GPUBuffer = bufA; + let curOut: GPUBuffer = bufB; + let totalPairAdds = 0; + let levelIdx = 0; + const dummy = device.createBuffer({ size: 16, usage: GPUBufferUsage.STORAGE }); + + interface PassSpec { pipeline: GPUComputePipeline; bind: GPUBindGroup; numWgs: number } + const allPasses: PassSpec[] = []; + const levelBufHolders: GPUBuffer[] = []; + + for (;;) { + let maxCount = 0; + for (let b = 0; b < counts.length; b++) if (counts[b] > maxCount) maxCount = counts[b]; + if (maxCount <= 1) { + finalOffsets = offsets; + break; + } + if (levelIdx > 24) throw new Error('exceeded safety level cap'); + + const plan = buildLevelPlan(counts, offsets, S, padLIdx, padRIdx, discardIdx); + totalPairAdds += plan.totalPairs; + const T = plan.numChunks; + const numWgs = Math.ceil(T / WGI); + log('info', `L${levelIdx}: T=${T} pairs=${plan.totalPairs} carries=${plan.numCarries} maxCount=${maxCount}`); + + const chunkPlanBuf = device.createBuffer({ size: plan.chunkPlan.byteLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(chunkPlanBuf, 0, plan.chunkPlan); + const scatterPlanBuf = device.createBuffer({ size: plan.scatterPlan.byteLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(scatterPlanBuf, 0, plan.scatterPlan); + const carryPlanBuf = device.createBuffer({ size: plan.carryPlan.byteLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(carryPlanBuf, 0, plan.carryPlan); + + const marshalParams = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(marshalParams, 0, new Uint32Array([T, M, 0, 0])); + const disjointParams = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(disjointParams, 0, new Uint32Array([2 * S * T, T, 1, 0])); + const scatterParams = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(scatterParams, 0, new Uint32Array([T, M, 0, 0])); + const carryParams = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + if (plan.numCarries > 0) { + device.queue.writeBuffer(carryParams, 0, new Uint32Array([plan.numCarries, M, M, 0])); + } + + const marshalBind = device.createBindGroup({ + layout: layoutMarshal, + entries: [ + { binding: 0, resource: { buffer: chunkPlanBuf } }, + { binding: 1, resource: { buffer: curIn } }, + { binding: 2, resource: { buffer: chainBuf } }, + { binding: 3, resource: { buffer: marshalParams } }, + ], + }); + const disjointBind = device.createBindGroup({ + layout: layoutDisjoint, + entries: [ + { binding: 0, resource: { buffer: chainBuf } }, + { binding: 1, resource: { buffer: dummy } }, + { binding: 2, resource: { buffer: tempOutBuf } }, + { binding: 3, resource: { buffer: disjointParams } }, + ], + }); + const scatterBind = device.createBindGroup({ + layout: layoutScatter, + entries: [ + { binding: 0, resource: { buffer: scatterPlanBuf } }, + { binding: 1, resource: { buffer: tempOutBuf } }, + { binding: 2, resource: { buffer: curOut } }, + { binding: 3, resource: { buffer: scatterParams } }, + ], + }); + let carryBind: GPUBindGroup | null = null; + if (plan.numCarries > 0) { + carryBind = device.createBindGroup({ + layout: layoutCarry, + entries: [ + { binding: 0, resource: { buffer: carryPlanBuf } }, + { binding: 1, resource: { buffer: curIn } }, + { binding: 2, resource: { buffer: curOut } }, + { binding: 3, resource: { buffer: carryParams } }, + ], + }); + } + + allPasses.push({ pipeline: marshalPipe, bind: marshalBind, numWgs }); + allPasses.push({ pipeline: disjointPipe, bind: disjointBind, numWgs }); + allPasses.push({ pipeline: scatterPipe, bind: scatterBind, numWgs }); + if (plan.numCarries > 0 && carryBind) { + const carryWgs = Math.ceil(plan.numCarries / WGI); + allPasses.push({ pipeline: carryPipe, bind: carryBind, numWgs: carryWgs }); + } + levelBufHolders.push(chunkPlanBuf, scatterPlanBuf, carryPlanBuf, marshalParams, disjointParams, scatterParams, carryParams); + + counts = plan.newCounts; + offsets = plan.newOffsets; + [curIn, curOut] = [curOut, curIn]; + levelIdx++; + } + + const enc = device.createCommandEncoder(); + for (const ps of allPasses) { + const pass = enc.beginComputePass(); + pass.setPipeline(ps.pipeline); + pass.setBindGroup(0, ps.bind); + pass.dispatchWorkgroups(ps.numWgs, 1, 1); + pass.end(); + } + const t0 = performance.now(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + const gpuWall = performance.now() - t0; + log('info', `batched ${allPasses.length} passes in one submit: ${gpuWall.toFixed(2)} ms`); + + const result = await readbackU32(device, curIn, 2 * PG * M * 4 * 4); + + bufA.destroy(); + bufB.destroy(); + chainBuf.destroy(); + tempOutBuf.destroy(); + dummy.destroy(); + for (const b of levelBufHolders) b.destroy(); + + const decodeAt = (slot: number): { x_mont: bigint; y_mont: bigint } => { + const xWords = new Uint32Array(8); + const yWords = new Uint32Array(8); + for (let v = 0; v < PG; v++) { + const baseX = ((0 * PG + v) * M + slot) * 4; + const baseY = ((1 * PG + v) * M + slot) * 4; + xWords[4 * v + 0] = result[baseX + 0]; + xWords[4 * v + 1] = result[baseX + 1]; + xWords[4 * v + 2] = result[baseX + 2]; + xWords[4 * v + 3] = result[baseX + 3]; + yWords[4 * v + 0] = result[baseY + 0]; + yWords[4 * v + 1] = result[baseY + 1]; + yWords[4 * v + 2] = result[baseY + 2]; + yWords[4 * v + 3] = result[baseY + 3]; + } + const x_mont = packedU32x8ToBigint(xWords, 0); + const y_mont = packedU32x8ToBigint(yWords, 0); + return { x_mont, y_mont }; + }; + + const G1 = bn254.G1.ProjectivePoint; + const refSumPerBucket = new Map(); + for (let b = 0; b < BUCKETS; b++) { + if (initCounts[b] === 0) continue; + let acc = G1.ZERO; + for (let i = 0; i < NPTS; i++) { + if (bucket[i] !== b) continue; + acc = acc.add(G1.fromAffine({ x: points[i].x, y: points[i].y })); + } + refSumPerBucket.set(b, acc.is0() ? null : acc.toAffine()); + } + + const checks: BucketCheck[] = []; + const mismatches: BucketCheck[] = []; + let passCount = 0; + for (let b = 0; b < BUCKETS; b++) { + if (initCounts[b] === 0) continue; + const slot = finalOffsets[b]; + const { x_mont, y_mont } = decodeAt(slot); + const gx = (x_mont * Rinv) % p; + const gy = (y_mont * Rinv) % p; + const ref = refSumPerBucket.get(b); + let ok = false; + if (ref === null) { + ok = gx === 0n && gy === 0n; + } else if (ref) { + ok = gx === ref.x && gy === ref.y; + } + const entry: BucketCheck = { + bucket: b, + count: initCounts[b], + gpu_x: gx.toString(16), + gpu_y: gy.toString(16), + ref_x: ref ? ref.x.toString(16) : 'INF', + ref_y: ref ? ref.y.toString(16) : 'INF', + ok, + }; + checks.push(entry); + if (ok) passCount++; + else if (mismatches.length < 8) mismatches.push(entry); + } + const allPassed = mismatches.length === 0 && passCount === checks.length; + + if (allPassed) { + log('ok', `oracle PASS — ${passCount}/${checks.length} buckets match noble reference`); + } else { + log('err', `oracle FAIL — ${checks.length - passCount}/${checks.length} buckets diverged (showing first ${mismatches.length})`); + for (const m of mismatches) { + log('err', ` bucket ${m.bucket} (count=${m.count})`); + log('err', ` gpu: x=${m.gpu_x} y=${m.gpu_y}`); + log('err', ` ref: x=${m.ref_x} y=${m.ref_y}`); + } + } + + return { + n: NPTS, + buckets: BUCKETS, + s: S, + wgi: WGI, + levels: levelIdx, + total_pair_adds: totalPairAdds, + buckets_checked: checks.length, + buckets_passed: passCount, + first_mismatches: mismatches, + all_passed: allPassed, + gpu_wall_ms: gpuWall, + }; +} + +function parseParams() { + const qp = new URLSearchParams(window.location.search); + if (qp.get('n')) NPTS = parseInt(qp.get('n')!, 10); + if (qp.get('buckets')) BUCKETS = parseInt(qp.get('buckets')!, 10); + if (qp.get('s')) S = parseInt(qp.get('s')!, 10); + if (qp.get('wgi')) WGI = parseInt(qp.get('wgi')!, 10); + if (qp.get('seed')) SEED = parseInt(qp.get('seed')!, 10); + return { n: NPTS, buckets: BUCKETS, s: S, wgi: WGI, seed: SEED }; +} + +async function main() { + try { + if (!('gpu' in navigator)) throw new Error('navigator.gpu missing'); + const params = parseParams(); + benchState.params = params; + log('info', `params: ${JSON.stringify(params)}`); + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + const p = BN254_BASE_FIELD; + const miscParams = compute_misc_params(p, 13); + const R = miscParams.r; + const Rinv = modInverse(R, p); + const sm = new ShaderManager(1, BUCKETS, BN254_CURVE_CONFIG, false); + const r = await runOracle(device, sm, R, Rinv, p); + benchState.results.push(r); + resultsClient.postProgress({ + kind: 'oracle_done', + all_passed: r.all_passed, + buckets_passed: r.buckets_passed, + buckets_checked: r.buckets_checked, + gpu_wall_ms: r.gpu_wall_ms, + }); + benchState.state = 'done'; + log('ok', 'done'); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main() + .catch(e => { + const msg = e instanceof Error ? e.message : String(e); + log('err', `unhandled: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + }) + .finally(() => { + postFinal().catch(() => {}); + }); diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs index f89804e80ce6..b9c2fc8ef1e7 100644 --- a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs +++ b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs @@ -136,6 +136,7 @@ const pageMap = { "bench-msm-tree-v3": "/dev/msm-webgpu/bench-msm-tree-v3.html", "bench-planner": "/dev/msm-webgpu/bench-planner.html", "bench-csr-to-v2": "/dev/msm-webgpu/bench-csr-to-v2.html", + "bench-msm-oracle": "/dev/msm-webgpu/bench-msm-oracle.html", "bench-smvp-tree": "/dev/msm-webgpu/bench-smvp-tree.html", sanity: "/dev/msm-webgpu/index.html", }; From 299b9fcbed4cd9ac47684cfbf5312e767244c47e Mon Sep 17 00:00:00 2001 From: AztecBot Date: Wed, 20 May 2026 14:11:36 +0000 Subject: [PATCH 23/33] fix(bb/msm): correct active_sums SoA layout in v2 oracle host init MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The v2 pair-tree kernels (marshal, disjoint, scatter, carry) all address active_sums as element-major-within-plane: vec4 index = plane * PG * M + PG * elem + v where PG=2 vec4 per field element. My oracle's writeElem/decodeAt were laying out v-major-within-plane (each plane's v=0 slabs first, then v=1 slabs), so the kernels read the high vec4 of x in the slot where the low vec4 of the next element should live. The pre-existing bench-msm-tree-v2 init uses the same wrong layout, but its sanity check is just "buffer is non-zero" — running on random non-curve data, garbage in → non-zero garbage out, sanity OK. The oracle (real BN254 points, noble cross-check) is the first thing that actually compared values against ground truth and caught it. After the fix the oracle passes 32/32 buckets at n=256, B=32, byte- equal to noble's projective sum per bucket. GPU wall ~8 ms (oracle sizing, not a perf measurement). bench-msm-tree-v2 has the same layout bug in its init but the perf numbers it reports (22.13 ns/in-pt single-submit) are kernel- throughput-true regardless of input meaning — they measure dispatch + kernel runtime on M2. A separate commit will port the same fix to that harness so its sanity check actually exercises curve-correct inputs. --- barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.ts | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.ts index a2ab8585ec09..2addcec51fae 100644 --- a/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.ts +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.ts @@ -135,8 +135,9 @@ function buildL0WithRealPoints( const cursor = new Uint32Array(B); const writeElem = (planeIdx: number, dstIdx: number, words: Uint32Array, srcOff: number) => { + const planeBase = planeIdx * PG * M; for (let v = 0; v < PG; v++) { - const base = ((planeIdx * PG + v) * M + dstIdx) * 4; + const base = (planeBase + PG * dstIdx + v) * 4; buf[base + 0] = words[srcOff + 4 * v + 0]; buf[base + 1] = words[srcOff + 4 * v + 1]; buf[base + 2] = words[srcOff + 4 * v + 2]; @@ -481,9 +482,11 @@ async function runOracle(device: GPUDevice, sm: ShaderManager, R: bigint, Rinv: const decodeAt = (slot: number): { x_mont: bigint; y_mont: bigint } => { const xWords = new Uint32Array(8); const yWords = new Uint32Array(8); + const planeBaseX = 0 * PG * M; + const planeBaseY = 1 * PG * M; for (let v = 0; v < PG; v++) { - const baseX = ((0 * PG + v) * M + slot) * 4; - const baseY = ((1 * PG + v) * M + slot) * 4; + const baseX = (planeBaseX + PG * slot + v) * 4; + const baseY = (planeBaseY + PG * slot + v) * 4; xWords[4 * v + 0] = result[baseX + 0]; xWords[4 * v + 1] = result[baseX + 1]; xWords[4 * v + 2] = result[baseX + 2]; From 0ace08692552efef49ac74e340f1a5d84c0c0430 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Wed, 20 May 2026 14:28:58 +0000 Subject: [PATCH 24/33] feat(bb/msm): v2 -> production running buffers adapter + orchestrator scaffold Step 3 first cut. Replaces the cuZK schedule + batch_inverse_parallel + apply_scatter round-loop in smvp_batch_affine_gpu with a single v2 pair- tree dispatch per window. This commit ships the architecturally-correct pieces and is honest about the one runtime gap that needs to land before the orchestrator can run end-to-end. What's correct and shippable: - wgsl/cuzk/v2_to_running.template.wgsl. Boundary adapter that copies the per-bucket reduced packed point from the v2 active_sums slot at final_offsets[b] into the production running_x / running_y layout indexed by bucket_global = subtask_idx * num_columns + bucket_local, and sets bucket_active per bucket. One thread per (subtask, bucket_local); the caller binds running_x / running_y / bucket_active sub-views offset by subtask_idx * num_columns so a per-window dispatch lands the result at the right global slot. - shader_manager.ts gen_v2_to_running_shader wiring. - cuzk/smvp_v2_pair_tree.ts orchestrator scaffold. Documents the full pipeline (csr_to_v2_meta -> csr_to_v2_active_sums -> per-level planner_v2 + marshal_pairs + pair_disjoint_tree + scatter_pairs + carry_copy -> v2_to_running), the buffer layouts at every boundary, and the per-window single-submit shape. Exports maxChunksUpperBound and a `sizes` helper so the msm.ts integration can pre-allocate matching scratch buffers. What's not yet runtime-correct (and why this commit's runSmvpV2PairTree throws): - planner_v2 (ba_planner_v2_bench.template.wgsl) writes only the first numChunks * S entries of chunk_plan / scatter_plan. It does not pad- fill the tail with (padLIdx, padRIdx) / discardIdx the way the host planner in bench-msm-tree-v2 does. Dispatching marshal / disjoint / scatter at the worst-case T_upper would then read stale or zero chunk_plan entries, compute garbage affine adds, and scatter them via stale scatter_plan entries -- corrupting real bucket slots in active_sums_new. Two paths to runtime correctness, both deferred to the follow-up: (a) Extend planner_v2 to take padLIdx / padRIdx / discardIdx as uniforms and pad-fill chunk_plan / scatter_plan / carry_plan tails in a final phase. Re-validate the standalone bench-planner harness against the updated host reference, then runSmvpV2PairTree can dispatch at T_upper safely. (b) Have planner_v2 emit per-level dispatch counts and switch marshal / disjoint / scatter / carry to dispatchWorkgroupsIndirect. Option (a) is the smaller change. The scaffold's bind-group plumbing is ready for it; only the planner shader needs the extra pad-fill phase. --- .../ts/src/msm_webgpu/cuzk/shader_manager.ts | 20 ++ .../src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts | 200 ++++++++++++++++++ .../src/msm_webgpu/wgsl/_generated/shaders.ts | 66 +++++- .../wgsl/cuzk/v2_to_running.template.wgsl | 62 ++++++ 4 files changed, 347 insertions(+), 1 deletion(-) create mode 100644 barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/v2_to_running.template.wgsl diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts index ab28aa1b4425..30c0b79231f9 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts @@ -42,6 +42,7 @@ import { convert_points_only as convert_points_only_shader, csr_to_v2_active_sums as csr_to_v2_active_sums_shader, csr_to_v2_meta as csr_to_v2_meta_shader, + v2_to_running as v2_to_running_shader, decompose_scalars_signed_only as decompose_scalars_signed_only_shader, decompress_g1_bn254 as decompress_g1_bn254_shader, divsteps_bench as divsteps_bench_shader, @@ -913,6 +914,25 @@ ${packLines.join('\n')} ); } + /** + * Boundary adapter for the v2 pair-tree -> production finalize: copies + * the per-bucket reduced packed point out of the v2 combined-SoA + * active_sums buffer into the production running_x / running_y layout + * and writes bucket_active. One thread per (subtask, bucket_local); + * the caller binds running_x / running_y / bucket_active sub-views + * offset by subtask_idx * num_columns. + */ + public gen_v2_to_running_shader(workgroup_size: number): string { + if (workgroup_size <= 0 || !Number.isInteger(workgroup_size)) { + throw new Error(`gen_v2_to_running_shader: workgroup_size (${workgroup_size}) must be a positive integer`); + } + return mustache.render( + v2_to_running_shader, + { workgroup_size, recompile: this.recompile }, + {}, + ); + } + /** * Layout converter (meta derivation): writes per-bucket count and * subtask-relative offset from cuZK row_ptr. One thread per diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts b/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts new file mode 100644 index 000000000000..98a8b2b96ee1 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts @@ -0,0 +1,200 @@ +/// + +/** + * v2 bin-packed pair-tree MSM bucket-accumulate orchestrator — + * step 3 of the rewrite from the cuZK round-loop to a single-submit + * pair-tree per pippenger window. + * + * Goal: a drop-in replacement for `smvp_batch_affine_gpu` (the + * schedule + batch_inverse_parallel + apply_scatter round-loop) that + * produces the same downstream contract — running_x / running_y / + * bucket_active per (subtask, bucket_local) — so the existing + * batch_affine_finalize_collect / finalize_apply / BPR / horner stages + * can consume the v2 output without any changes. + * + * Pipeline per window: + * + * csr_to_v2_meta row_ptr -> counts[B] + offsets[B] + * csr_to_v2_active_sums val_idx + cached bases (packed 8x u32) -> + * bucket-major active_sums in v2 combined SoA + * for level in 0..max_levels: + * planner_v2 counts/offsets -> chunk_plan / scatter_plan + * / carry_plan + new_counts / new_offsets + + * totals + * marshal_pairs active_sums + chunk_plan -> chain_buf + * pair_disjoint_tree chain_buf -> tempOut (S pair sums per chunk, + * single-fr_inv per chunk, lean affine add) + * scatter_pairs tempOut + scatter_plan -> active_sums_next + * carry_copy odd-count tails -> active_sums_next + * v2_to_running final active_sums slot per non-empty bucket + * -> running_x / running_y / bucket_active + * (production layout, ready for finalize) + * + * Layouts: + * active_sums (combined SoA, one buffer per ping-pong copy): + * plane 0 (x) vec4 indices [0, PG * M) + * plane 1 (y) vec4 indices [PG * M, 2 * PG * M) + * per-element layout: PG=2 vec4 at [PG*elem, PG*elem+1]. + * M = input_size + 2 (last 2 slots hold a pad pair). + * running_x / running_y (production, separate buffers): + * packed 8x u32 = 2 vec4 per (subtask, bucket_local), at + * [PG * bucket_global, PG * bucket_global + 1] with + * bucket_global = subtask_idx * num_columns + bucket_local. + * bucket_active: u32 per bucket_global. + * + * @remarks IMPLEMENTATION STATUS — Step 3 scaffolding only. + * + * `runSmvpV2PairTree` below is **not yet runtime-correct** because the + * planner_v2 shader (`ba_planner_v2_bench.template.wgsl`) writes only + * the first `numChunks * S` entries of chunk_plan / scatter_plan — it + * does not pad-fill the tail with (padLIdx, padRIdx) / discardIdx the + * way the host planner in bench-msm-tree-v2 does. Dispatching + * marshal_pairs / pair_disjoint_tree / scatter_pairs at the worst-case + * `T_upper` (the buffer's allocated chunk count) would then read stale + * / zero entries from chunk_plan, compute garbage affine adds, and + * scatter the garbage into real bucket slots via stale scatter_plan + * entries — corrupting the result. + * + * Two correct paths to land next: + * (a) Extend planner_v2 to take padLIdx / padRIdx / discardIdx + * uniforms and pad-fill chunk_plan / scatter_plan / carry_plan + * tails in a final phase. Re-validate the standalone bench- + * planner harness against an updated host reference, then this + * orchestrator can dispatch at T_upper safely. + * (b) Have planner_v2 write per-level dispatch counts (numChunks / + * numCarries derived from totals) into a small dispatch_args + * buffer, and switch marshal / disjoint / scatter / carry to + * `dispatchWorkgroupsIndirect`. Avoids the pad-fill but needs a + * per-level uniform-vs-storage rewrite for the T-and-N + * parameters that those four kernels currently read from + * `var`. + * + * Option (a) is the simpler change. The scaffolding below records + * pipeline compiles + bind-group construction so the eventual runtime + * is a small delta once planner_v2 pad-fill lands. + * + * The companion `v2_to_running` shader (`v2_to_running.template.wgsl`) + * is finished and correct: it copies the final per-bucket reduced + * packed point from the v2 active_sums slot into the production + * running_x / running_y / bucket_active layout at the correct + * bucket_global. Its bindings allow per-subtask views (offset by + * subtask_idx * num_columns) so a single per-window dispatch lands the + * result in the right slab of the global running buffers. + */ + +import { ShaderManager } from './shader_manager.js'; + +const PG = 2; + +export interface SmvpV2PairTreeOptions { + device: GPUDevice; + shaderManager: ShaderManager; + num_subtasks: number; + num_columns: number; + input_size: number; + + s?: number; + tpb?: number; + per_thread?: number; + wgi?: number; + max_levels?: number; + + /** Per-subtask CSR row_ptr layout from cuZK transpose. */ + val_idx_buf: GPUBuffer; + /** Per-subtask CSR row_ptr (num_columns + 1 entries per subtask). */ + row_ptr_buf: GPUBuffer; + /** Packed cached_bases.point_x_sb (input_size * 32 bytes). */ + point_x_buf: GPUBuffer; + /** Packed cached_bases.point_y_sb (input_size * 32 bytes). */ + point_y_buf: GPUBuffer; + + /** + * Output: running_x / running_y per bucket_global, packed 8x u32. + * Sized num_subtasks * num_columns * 32 bytes each. + */ + running_x_buf: GPUBuffer; + running_y_buf: GPUBuffer; + /** Output: bucket_active per bucket_global, u32. */ + bucket_active_buf: GPUBuffer; +} + +export interface SmvpV2PairTreeStats { + levels_per_window: number; + pipelines_compiled: number; + bind_groups_recorded: number; +} + +/** + * Construct the v2 bucket-accumulate dispatch chain. + * + * @throws Always — runtime is gated on the planner_v2 pad-fill + * follow-up described in this module's docstring. + */ +export async function runSmvpV2PairTree( + _opts: SmvpV2PairTreeOptions, +): Promise { + throw new Error( + 'smvp_v2_pair_tree: orchestrator scaffolding is checked in but ' + + 'runtime is gated on planner_v2 pad-fill (option a) or indirect ' + + 'dispatch (option b). See module docstring.', + ); +} + +/** + * Reference upper bound on the chunk count any level can produce, used + * by the orchestrator to size chunk_plan / scatter_plan / chain_buf / + * tempOut at the worst case (level 0 with all pairs). + * + * Per-bucket count C, total active points N (sum of counts), per-level + * pair count is bounded by floor(N / 2). After bin-packing into chunks + * of S, numChunks <= ceil(N / 2 / S). Plus a +num_columns slack for the + * carry-forward elements that bump some buckets at the next level. + */ +export function maxChunksUpperBound(input_size: number, num_columns: number, s: number): number { + return Math.max(1, Math.ceil(input_size / 2 / s) + num_columns); +} + +/** + * Buffer-byte-size helpers — kept here so the production msm.ts + * integration can pre-allocate matching scratch when wiring v2 in + * behind a flag. + */ +export const sizes = { + /** Combined-SoA active_sums byte size for one window, including the pad pair. */ + activeSumsBytes(input_size: number): number { + const M = input_size + 2; + return 2 * PG * M * 16; + }, + /** chain_buf byte size for one window. */ + chainBufBytes(input_size: number, num_columns: number, s: number): number { + const T = maxChunksUpperBound(input_size, num_columns, s); + return 2 * PG * (2 * s * T) * 16; + }, + /** tempOut byte size for one window. */ + tempOutBytes(input_size: number, num_columns: number, s: number): number { + const T = maxChunksUpperBound(input_size, num_columns, s); + return 2 * PG * (s * T) * 16; + }, + /** chunk_plan byte size per level. */ + chunkPlanBytes(input_size: number, num_columns: number, s: number): number { + const T = maxChunksUpperBound(input_size, num_columns, s); + return 2 * s * T * 4; + }, + /** scatter_plan byte size per level. */ + scatterPlanBytes(input_size: number, num_columns: number, s: number): number { + const T = maxChunksUpperBound(input_size, num_columns, s); + return s * T * 4; + }, + /** carry_plan byte size per level. */ + carryPlanBytes(num_columns: number): number { + return 2 * num_columns * 4; + }, + /** counts byte size per level. */ + countsBytes(num_columns: number): number { + return num_columns * 4; + }, + /** offsets byte size per level. */ + offsetsBytes(num_columns: number): number { + return num_columns * 4; + }, +}; diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index 60ee225eba91..f9869415faaf 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -1,6 +1,6 @@ // AUTO-GENERATED by scripts/inline-wgsl.mjs. DO NOT EDIT. // Run `yarn generate:wgsl` (or `node scripts/inline-wgsl.mjs`) to regenerate. -// 64 shader sources inlined. +// 65 shader sources inlined. /* eslint-disable */ @@ -8624,6 +8624,70 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { {{{ recompile }}} }`; +export const v2_to_running = `// Boundary adapter from the v2 bin-packed pair-tree's per-window +// active_sums buffer (combined SoA, plane 0 = X / plane 1 = Y at vec4 +// indices [PG*elem + v]) to the production running_x / running_y / +// bucket_active layout that batch_affine_finalize_collect consumes. +// +// Per-window dispatch: one thread per (subtask, bucket_local). The +// caller binds the per-window active_sums (combined SoA), the final +// counts and offsets emitted by the planner's last level, and views of +// the global running_x / running_y / bucket_active arrays offset by +// subtask_idx * num_columns so a single bucket_global is addressable +// via gid.x. +// +// For non-empty buckets the v2 pair-tree has reduced the bucket to one +// packed-Montgomery point sitting at active_sums[final_offsets[b]] in +// the input plane layout. We copy that element into running_x / +// running_y at the matching bucket_global slot (packed 8x u32 = two +// vec4 per element, same layout production already uses when packed). +// Empty buckets only set bucket_active = 0 — running_x / running_y are +// left untouched; finalize is gated on bucket_active and never reads +// the unwritten slot. + +const PG: u32 = 2u; + +@group(0) @binding(0) var active_sums: array>; +@group(0) @binding(1) var final_counts: array; +@group(0) @binding(2) var final_offsets: array; +@group(0) @binding(3) var running_x: array>; +@group(0) @binding(4) var running_y: array>; +@group(0) @binding(5) var bucket_active: array; +@group(0) @binding(6) var params: vec4; +// params.x = num_columns (active per-window bucket count) +// params.y = M (elements per plane in the v2 active_sums buffer) + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let bucket_local = gid.x; + let num_columns = params.x; + let M = params.y; + if (bucket_local >= num_columns) { + return; + } + + let count = final_counts[bucket_local]; + if (count == 0u) { + bucket_active[bucket_local] = 0u; + return; + } + + bucket_active[bucket_local] = 1u; + + let slot = final_offsets[bucket_local]; + let plane_x_base = PG * slot; + let plane_y_base = PG * M + PG * slot; + let dst = PG * bucket_local; + + running_x[dst + 0u] = active_sums[plane_x_base + 0u]; + running_x[dst + 1u] = active_sums[plane_x_base + 1u]; + running_y[dst + 0u] = active_sums[plane_y_base + 0u]; + running_y[dst + 1u] = active_sums[plane_y_base + 1u]; + + {{{ recompile }}} +} +`; + export const by_inverse = `// Bernstein-Yang safegcd inversion for the BN254 base field, WGSL port. // // This file will grow over sub-steps 1.3-1.5 of the WebGPU MSM rewrite plan diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/v2_to_running.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/v2_to_running.template.wgsl new file mode 100644 index 000000000000..cadcd5043753 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/v2_to_running.template.wgsl @@ -0,0 +1,62 @@ +// Boundary adapter from the v2 bin-packed pair-tree's per-window +// active_sums buffer (combined SoA, plane 0 = X / plane 1 = Y at vec4 +// indices [PG*elem + v]) to the production running_x / running_y / +// bucket_active layout that batch_affine_finalize_collect consumes. +// +// Per-window dispatch: one thread per (subtask, bucket_local). The +// caller binds the per-window active_sums (combined SoA), the final +// counts and offsets emitted by the planner's last level, and views of +// the global running_x / running_y / bucket_active arrays offset by +// subtask_idx * num_columns so a single bucket_global is addressable +// via gid.x. +// +// For non-empty buckets the v2 pair-tree has reduced the bucket to one +// packed-Montgomery point sitting at active_sums[final_offsets[b]] in +// the input plane layout. We copy that element into running_x / +// running_y at the matching bucket_global slot (packed 8x u32 = two +// vec4 per element, same layout production already uses when packed). +// Empty buckets only set bucket_active = 0 — running_x / running_y are +// left untouched; finalize is gated on bucket_active and never reads +// the unwritten slot. + +const PG: u32 = 2u; + +@group(0) @binding(0) var active_sums: array>; +@group(0) @binding(1) var final_counts: array; +@group(0) @binding(2) var final_offsets: array; +@group(0) @binding(3) var running_x: array>; +@group(0) @binding(4) var running_y: array>; +@group(0) @binding(5) var bucket_active: array; +@group(0) @binding(6) var params: vec4; +// params.x = num_columns (active per-window bucket count) +// params.y = M (elements per plane in the v2 active_sums buffer) + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let bucket_local = gid.x; + let num_columns = params.x; + let M = params.y; + if (bucket_local >= num_columns) { + return; + } + + let count = final_counts[bucket_local]; + if (count == 0u) { + bucket_active[bucket_local] = 0u; + return; + } + + bucket_active[bucket_local] = 1u; + + let slot = final_offsets[bucket_local]; + let plane_x_base = PG * slot; + let plane_y_base = PG * M + PG * slot; + let dst = PG * bucket_local; + + running_x[dst + 0u] = active_sums[plane_x_base + 0u]; + running_x[dst + 1u] = active_sums[plane_x_base + 1u]; + running_y[dst + 0u] = active_sums[plane_y_base + 0u]; + running_y[dst + 1u] = active_sums[plane_y_base + 1u]; + + {{{ recompile }}} +} From 5b5ffabbcbf279b124debbc1bd55fb03e944d40a Mon Sep 17 00:00:00 2001 From: AztecBot Date: Wed, 20 May 2026 14:42:42 +0000 Subject: [PATCH 25/33] feat(bb/msm): v2 prod kernels (indirect dispatch) + runnable orchestrator Picks the indirect-dispatch path over chunk_plan pad-fill. The pad-fill alternative would have dispatched marshal / pair_disjoint / scatter at the worst-case T (chunk_plan capacity) and computed an affine add on every chunk including the pad ones, scattered to the discard slot. For BN254 production sizing (N ~= 2^16, B ~= 32k, S = 16) the worst level has ~150 pad chunks; later levels approach all-pad as the bucket counts collapse. Estimated waste: ~30-50 ms of GPU compute per MSM call. With indirect dispatch off the planner's per-level totals, EXACTLY num_chunks workgroups run per level -- zero pad-chunk waste at runtime. Shaders (5 new): - ba_planner_v2_prod.template.wgsl: ba_planner_v2_bench with totals extended by the per-level dispatch_args triples that drive the four downstream prod kernels via dispatchWorkgroupsIndirect. totals[0..2] = (total_pairs, total_carries, total_new) (unchanged bench-compat values) totals[3] = num_chunks = ceil(total_pairs / S) totals[4..6] = marshal/disjoint/scatter dispatch X (ceil(num_chunks / WGI), 1, 1) totals[7..9] = carry dispatch X (ceil(total_carries / WGI), 1, 1) WGI is a compile-time constant matching the downstream prod kernels' workgroup_size. - ba_marshal_pairs_prod.template.wgsl: marshal with T read from totals[3] (storage), M_in from a tiny consts uniform. - ba_pair_disjoint_tree_prod.template.wgsl: disjoint pair-sum tree with T_curr read from totals[3] (storage); always uses the final-mode strided write so it pairs with ba_scatter_pairs_prod. - ba_scatter_pairs_prod.template.wgsl: scatter with T read from totals[3] (storage), M_new from consts. - ba_carry_copy_prod.template.wgsl: carry copy with num_carries read from totals[1] (storage), M_old/M_new from consts. Wired into shader_manager.ts via gen_ba_planner_v2_prod_shader, gen_ba_marshal_pairs_prod_shader, gen_ba_pair_disjoint_tree_prod_shader, gen_ba_scatter_pairs_prod_shader, gen_ba_carry_copy_prod_shader. Orchestrator (`cuzk/smvp_v2_pair_tree.ts`) is now runnable. Per call: compile the 8 pipelines (csr_to_v2_meta, csr_to_v2_active_sums, planner_v2_prod, marshal_prod, disjoint_prod, scatter_prod, carry_prod, v2_to_running), allocate the per-window scratch and per-level plan + totals buffers (totals has STORAGE | INDIRECT | COPY_DST usage), seed the pad pair in both ping-pong active_sums, then for each window record csr meta + active_sums fill + max_levels x (planner direct + marshal/disjoint/scatter/carry indirect dispatch via totals offsets 16 and 28) + v2_to_running. Single encoder, single queue.submit per call. The runtime is now wired but unvalidated end-to-end; the next commit adds the noble-CPU oracle harness on top of runSmvpV2PairTree and exercises it on BrowserStack M2. --- .../ts/src/msm_webgpu/cuzk/shader_manager.ts | 110 +++ .../src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts | 625 ++++++++++++++---- .../src/msm_webgpu/wgsl/_generated/shaders.ts | 458 ++++++++++++- .../cuzk/ba_carry_copy_prod.template.wgsl | 44 ++ .../cuzk/ba_marshal_pairs_prod.template.wgsl | 65 ++ .../ba_pair_disjoint_tree_prod.template.wgsl | 119 ++++ .../cuzk/ba_planner_v2_prod.template.wgsl | 170 +++++ .../cuzk/ba_scatter_pairs_prod.template.wgsl | 48 ++ 8 files changed, 1522 insertions(+), 117 deletions(-) create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_carry_copy_prod.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_pairs_prod.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_pair_disjoint_tree_prod.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_prod.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_scatter_pairs_prod.template.wgsl diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts index 30c0b79231f9..24bd0a2f4ba1 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts @@ -14,6 +14,11 @@ import { ba_pair_disjoint_tree_bench as ba_pair_disjoint_tree_bench_shader, ba_planner_bench as ba_planner_bench_shader, ba_planner_v2_bench as ba_planner_v2_bench_shader, + ba_planner_v2_prod as ba_planner_v2_prod_shader, + ba_marshal_pairs_prod as ba_marshal_pairs_prod_shader, + ba_pair_disjoint_tree_prod as ba_pair_disjoint_tree_prod_shader, + ba_scatter_pairs_prod as ba_scatter_pairs_prod_shader, + ba_carry_copy_prod as ba_carry_copy_prod_shader, ba_scatter_pairs_bench as ba_scatter_pairs_bench_shader, ba_tail_reduce_bench as ba_tail_reduce_bench_shader, ba_rev_packed_carry_bench as ba_rev_packed_carry_bench_shader, @@ -914,6 +919,111 @@ ${packLines.join('\n')} ); } + /** + * Production v2 GPU planner. Same algorithm as + * gen_ba_planner_v2_bench_shader, additionally emits per-level + * dispatch_args triples into totals[4..6] (marshal/disjoint/scatter) + * and totals[7..9] (carry) so the host orchestrator can drive the + * four downstream prod kernels via dispatchWorkgroupsIndirect with + * zero pad-chunk waste. wgi must match the workgroup size used to + * compile the four downstream prod kernels. + */ + public gen_ba_planner_v2_prod_shader(workgroup_size: number, per_thread: number, s: number, wgi: number, pair_cap: number = 64): string { + if (workgroup_size <= 0 || per_thread <= 0 || s <= 0 || wgi <= 0 || pair_cap <= 0 || + !Number.isInteger(workgroup_size) || !Number.isInteger(per_thread) || !Number.isInteger(s) || !Number.isInteger(wgi) || !Number.isInteger(pair_cap)) { + throw new Error(`gen_ba_planner_v2_prod_shader: positive integer args required`); + } + return mustache.render( + ba_planner_v2_prod_shader, + { workgroup_size, per_thread, pair_cap, s, wgi, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + + /** + * Marshal pairs — prod variant. Reads num_chunks from + * totals[3] (storage), dispatched indirectly off totals[4..6]. + */ + public gen_ba_marshal_pairs_prod_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_marshal_pairs_prod_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + return mustache.render( + ba_marshal_pairs_prod_shader, + { workgroup_size, s, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + + /** + * Disjoint pair-sum tree — prod variant. Reads T_curr from + * totals[3] (storage), always uses the final-mode strided write. + */ + public gen_ba_pair_disjoint_tree_prod_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_pair_disjoint_tree_prod_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + const dec = this.decoupledPackUnpackWgsl(); + return mustache.render( + ba_pair_disjoint_tree_prod_shader, + { + workgroup_size, + s, + word_size: this.word_size, + num_words: this.num_words, + n0: this.n0, + p_limbs: this.p_limbs, + r_limbs: this.r_limbs, + r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, + mask: this.mask, + two_pow_word_size: this.two_pow_word_size, + p_inv_mod_2w: this.p_inv_mod_2w, + p_inv_by_a_lo: this.p_inv_by_a_lo, + dec_unpack: dec.unpack, + dec_pack: dec.pack, + recompile: this.recompile, + }, + { + structs, + bigint_funcs, + montgomery_product_funcs: this.mont_product_src, + field_funcs, + fr_pow_funcs, + bigint_by_funcs, + by_inverse_a_funcs, + }, + ); + } + + /** + * Scatter pairs — prod variant. Reads T from totals[3] (storage). + */ + public gen_ba_scatter_pairs_prod_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_scatter_pairs_prod_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + return mustache.render( + ba_scatter_pairs_prod_shader, + { workgroup_size, s, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + + /** + * Carry copy — prod variant. Reads num_carries from totals[1] (storage). + */ + public gen_ba_carry_copy_prod_shader(workgroup_size: number): string { + if (workgroup_size <= 0 || !Number.isInteger(workgroup_size)) { + throw new Error(`gen_ba_carry_copy_prod_shader: workgroup_size (${workgroup_size}) must be a positive integer`); + } + return mustache.render( + ba_carry_copy_prod_shader, + { workgroup_size, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + /** * Boundary adapter for the v2 pair-tree -> production finalize: copies * the per-bucket reduced packed point out of the v2 combined-SoA diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts b/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts index 98a8b2b96ee1..de9979bd1237 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts @@ -1,90 +1,65 @@ /// /** - * v2 bin-packed pair-tree MSM bucket-accumulate orchestrator — - * step 3 of the rewrite from the cuZK round-loop to a single-submit - * pair-tree per pippenger window. + * v2 bin-packed pair-tree MSM bucket-accumulate orchestrator. * - * Goal: a drop-in replacement for `smvp_batch_affine_gpu` (the - * schedule + batch_inverse_parallel + apply_scatter round-loop) that - * produces the same downstream contract — running_x / running_y / - * bucket_active per (subtask, bucket_local) — so the existing - * batch_affine_finalize_collect / finalize_apply / BPR / horner stages - * can consume the v2 output without any changes. + * Drop-in replacement for the cuZK round-loop (`smvp_batch_affine_gpu`'s + * schedule + batch_inverse_parallel + apply_scatter per round) for the + * Pippenger bucket-accumulate phase. For each window: * - * Pipeline per window: - * - * csr_to_v2_meta row_ptr -> counts[B] + offsets[B] - * csr_to_v2_active_sums val_idx + cached bases (packed 8x u32) -> - * bucket-major active_sums in v2 combined SoA + * csr_to_v2_meta row_ptr -> per-bucket count + offset + * csr_to_v2_active_sums val_idx + cached bases -> bucket-major + * active_sums (combined SoA, packed 8x u32) * for level in 0..max_levels: - * planner_v2 counts/offsets -> chunk_plan / scatter_plan - * / carry_plan + new_counts / new_offsets + - * totals - * marshal_pairs active_sums + chunk_plan -> chain_buf - * pair_disjoint_tree chain_buf -> tempOut (S pair sums per chunk, - * single-fr_inv per chunk, lean affine add) - * scatter_pairs tempOut + scatter_plan -> active_sums_next - * carry_copy odd-count tails -> active_sums_next - * v2_to_running final active_sums slot per non-empty bucket - * -> running_x / running_y / bucket_active - * (production layout, ready for finalize) + * ba_planner_v2_prod counts/offsets -> chunk_plan / scatter_plan + * / carry_plan + new_counts/new_offsets + + * totals (incl. per-level dispatch_args + * triples at totals[4..6] and totals[7..9]) + * ba_marshal_pairs_prod indirect dispatch off totals[4..6] + * ba_pair_disjoint_tree_prod indirect dispatch off totals[4..6] + * ba_scatter_pairs_prod indirect dispatch off totals[4..6] + * ba_carry_copy_prod indirect dispatch off totals[7..9] + * v2_to_running final active_sums slot per non-empty + * bucket -> running_x / running_y / + * bucket_active in production layout + * + * The prod kernels read num_chunks (= ceil(total_pairs / S)) and + * num_carries from the planner's totals storage output; the host + * dispatches them via dispatchWorkgroupsIndirect. Each level's + * downstream dispatch is sized to EXACTLY the chunks/carries the + * planner produced — zero pad-chunk waste, the runtime advantage + * over the pad-fill alternative. * - * Layouts: + * All dispatches for all windows are recorded onto one command encoder + * and submitted once. Submit overhead is paid once per MSM, not once + * per window or once per level. + * + * Layout boundaries: * active_sums (combined SoA, one buffer per ping-pong copy): - * plane 0 (x) vec4 indices [0, PG * M) - * plane 1 (y) vec4 indices [PG * M, 2 * PG * M) - * per-element layout: PG=2 vec4 at [PG*elem, PG*elem+1]. - * M = input_size + 2 (last 2 slots hold a pad pair). + * plane 0 (x) at vec4 indices [0, PG * M) + * plane 1 (y) at vec4 indices [PG * M, 2 * PG * M) + * element layout: PG=2 vec4 at [PG*elem, PG*elem+1] + * M = input_size + 2 (last 2 slots are the pad pair the planner + * emits into the chunk-tail for filler pairs — we initialise the + * pad pair once at orchestrator start with distinct-x Montgomery- + * form values so the disjoint kernel's lean affine add is well- + * defined on pad chunks even though they get scattered to the + * discard slot) * running_x / running_y (production, separate buffers): * packed 8x u32 = 2 vec4 per (subtask, bucket_local), at * [PG * bucket_global, PG * bucket_global + 1] with - * bucket_global = subtask_idx * num_columns + bucket_local. - * bucket_active: u32 per bucket_global. - * - * @remarks IMPLEMENTATION STATUS — Step 3 scaffolding only. - * - * `runSmvpV2PairTree` below is **not yet runtime-correct** because the - * planner_v2 shader (`ba_planner_v2_bench.template.wgsl`) writes only - * the first `numChunks * S` entries of chunk_plan / scatter_plan — it - * does not pad-fill the tail with (padLIdx, padRIdx) / discardIdx the - * way the host planner in bench-msm-tree-v2 does. Dispatching - * marshal_pairs / pair_disjoint_tree / scatter_pairs at the worst-case - * `T_upper` (the buffer's allocated chunk count) would then read stale - * / zero entries from chunk_plan, compute garbage affine adds, and - * scatter the garbage into real bucket slots via stale scatter_plan - * entries — corrupting the result. - * - * Two correct paths to land next: - * (a) Extend planner_v2 to take padLIdx / padRIdx / discardIdx - * uniforms and pad-fill chunk_plan / scatter_plan / carry_plan - * tails in a final phase. Re-validate the standalone bench- - * planner harness against an updated host reference, then this - * orchestrator can dispatch at T_upper safely. - * (b) Have planner_v2 write per-level dispatch counts (numChunks / - * numCarries derived from totals) into a small dispatch_args - * buffer, and switch marshal / disjoint / scatter / carry to - * `dispatchWorkgroupsIndirect`. Avoids the pad-fill but needs a - * per-level uniform-vs-storage rewrite for the T-and-N - * parameters that those four kernels currently read from - * `var`. - * - * Option (a) is the simpler change. The scaffolding below records - * pipeline compiles + bind-group construction so the eventual runtime - * is a small delta once planner_v2 pad-fill lands. - * - * The companion `v2_to_running` shader (`v2_to_running.template.wgsl`) - * is finished and correct: it copies the final per-bucket reduced - * packed point from the v2 active_sums slot into the production - * running_x / running_y / bucket_active layout at the correct - * bucket_global. Its bindings allow per-subtask views (offset by - * subtask_idx * num_columns) so a single per-window dispatch lands the - * result in the right slab of the global running buffers. + * bucket_global = subtask_idx * num_columns + bucket_local + * bucket_active: u32 per bucket_global + * v2_to_running binds running_x / running_y / bucket_active with a + * subtask_idx * num_columns byte offset so a single per-window + * dispatch lands the result at the right slab. */ import { ShaderManager } from './shader_manager.js'; const PG = 2; +const PG_VEC4_BYTES = 16; +const ELEMENT_BYTES = PG * PG_VEC4_BYTES; export interface SmvpV2PairTreeOptions { device: GPUDevice; @@ -99,101 +74,519 @@ export interface SmvpV2PairTreeOptions { wgi?: number; max_levels?: number; - /** Per-subtask CSR row_ptr layout from cuZK transpose. */ val_idx_buf: GPUBuffer; - /** Per-subtask CSR row_ptr (num_columns + 1 entries per subtask). */ row_ptr_buf: GPUBuffer; - /** Packed cached_bases.point_x_sb (input_size * 32 bytes). */ point_x_buf: GPUBuffer; - /** Packed cached_bases.point_y_sb (input_size * 32 bytes). */ point_y_buf: GPUBuffer; - /** - * Output: running_x / running_y per bucket_global, packed 8x u32. - * Sized num_subtasks * num_columns * 32 bytes each. - */ running_x_buf: GPUBuffer; running_y_buf: GPUBuffer; - /** Output: bucket_active per bucket_global, u32. */ bucket_active_buf: GPUBuffer; } export interface SmvpV2PairTreeStats { levels_per_window: number; - pipelines_compiled: number; - bind_groups_recorded: number; + num_subtasks: number; + num_columns: number; + total_passes: number; + gpu_wall_ms: number; } -/** - * Construct the v2 bucket-accumulate dispatch chain. - * - * @throws Always — runtime is gated on the planner_v2 pad-fill - * follow-up described in this module's docstring. - */ -export async function runSmvpV2PairTree( - _opts: SmvpV2PairTreeOptions, -): Promise { - throw new Error( - 'smvp_v2_pair_tree: orchestrator scaffolding is checked in but ' + - 'runtime is gated on planner_v2 pad-fill (option a) or indirect ' + - 'dispatch (option b). See module docstring.', - ); +function roStorageEntry(binding: number): GPUBindGroupLayoutEntry { + return { binding, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }; +} +function rwStorageEntry(binding: number): GPUBindGroupLayoutEntry { + return { binding, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }; +} +function uniformEntry(binding: number): GPUBindGroupLayoutEntry { + return { binding, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }; +} + +async function compilePipeline( + device: GPUDevice, + layout: GPUBindGroupLayout, + code: string, + key: string, +): Promise { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + const errs: string[] = []; + for (const m of info.messages) { + const line = `[smvp-v2 ${key}] ${m.type}: ${m.message} (line ${m.lineNum}, col ${m.linePos})`; + if (m.type === 'error') { + console.error(line); + errs.push(line); + } else { + console.warn(line); + } + } + if (errs.length > 0) { + throw new Error(`WGSL compile failed for ${key}: ${errs.slice(0, 4).join(' | ')}`); + } + return device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); +} + +interface Pipelines { + csrMeta: GPUComputePipeline; + csrActive: GPUComputePipeline; + planner: GPUComputePipeline; + marshal: GPUComputePipeline; + disjoint: GPUComputePipeline; + scatter: GPUComputePipeline; + carry: GPUComputePipeline; + v2ToRunning: GPUComputePipeline; + layouts: { + meta: GPUBindGroupLayout; + active: GPUBindGroupLayout; + planner: GPUBindGroupLayout; + marshal: GPUBindGroupLayout; + disjoint: GPUBindGroupLayout; + scatter: GPUBindGroupLayout; + carry: GPUBindGroupLayout; + v2Run: GPUBindGroupLayout; + }; +} + +async function compileAll( + device: GPUDevice, + sm: ShaderManager, + wgi: number, + s: number, + tpb: number, + per_thread: number, +): Promise { + const layouts: Pipelines['layouts'] = { + meta: device.createBindGroupLayout({ + entries: [roStorageEntry(0), rwStorageEntry(1), rwStorageEntry(2), uniformEntry(3)], + }), + active: device.createBindGroupLayout({ + entries: [ + roStorageEntry(0), + roStorageEntry(1), + roStorageEntry(2), + rwStorageEntry(3), + rwStorageEntry(4), + uniformEntry(5), + ], + }), + planner: device.createBindGroupLayout({ + entries: [ + roStorageEntry(0), + roStorageEntry(1), + rwStorageEntry(2), + rwStorageEntry(3), + rwStorageEntry(4), + rwStorageEntry(5), + rwStorageEntry(6), + rwStorageEntry(7), + uniformEntry(8), + ], + }), + marshal: device.createBindGroupLayout({ + entries: [roStorageEntry(0), roStorageEntry(1), rwStorageEntry(2), roStorageEntry(3), uniformEntry(4)], + }), + disjoint: device.createBindGroupLayout({ + entries: [roStorageEntry(0), roStorageEntry(1), rwStorageEntry(2), roStorageEntry(3)], + }), + scatter: device.createBindGroupLayout({ + entries: [roStorageEntry(0), roStorageEntry(1), rwStorageEntry(2), roStorageEntry(3), uniformEntry(4)], + }), + carry: device.createBindGroupLayout({ + entries: [roStorageEntry(0), roStorageEntry(1), rwStorageEntry(2), roStorageEntry(3), uniformEntry(4)], + }), + v2Run: device.createBindGroupLayout({ + entries: [ + roStorageEntry(0), + roStorageEntry(1), + roStorageEntry(2), + rwStorageEntry(3), + rwStorageEntry(4), + rwStorageEntry(5), + uniformEntry(6), + ], + }), + }; + + const [csrMeta, csrActive, planner, marshal, disjoint, scatter, carry, v2ToRunning] = await Promise.all([ + compilePipeline(device, layouts.meta, sm.gen_csr_to_v2_meta_shader(wgi), `csr-meta-wg${wgi}`), + compilePipeline(device, layouts.active, sm.gen_csr_to_v2_active_sums_shader(wgi), `csr-active-wg${wgi}`), + compilePipeline( + device, + layouts.planner, + sm.gen_ba_planner_v2_prod_shader(tpb, per_thread, s, wgi, 64), + `planner-v2-prod-T${tpb}-P${per_thread}-S${s}-W${wgi}`, + ), + compilePipeline(device, layouts.marshal, sm.gen_ba_marshal_pairs_prod_shader(wgi, s), `marshal-prod-W${wgi}-S${s}`), + compilePipeline(device, layouts.disjoint, sm.gen_ba_pair_disjoint_tree_prod_shader(wgi, s), `disjoint-prod-W${wgi}-S${s}`), + compilePipeline(device, layouts.scatter, sm.gen_ba_scatter_pairs_prod_shader(wgi, s), `scatter-prod-W${wgi}-S${s}`), + compilePipeline(device, layouts.carry, sm.gen_ba_carry_copy_prod_shader(wgi), `carry-prod-W${wgi}`), + compilePipeline(device, layouts.v2Run, sm.gen_v2_to_running_shader(wgi), `v2-to-running-wg${wgi}`), + ]); + return { csrMeta, csrActive, planner, marshal, disjoint, scatter, carry, v2ToRunning, layouts }; +} + +interface Scratch { + activeA: GPUBuffer; + activeB: GPUBuffer; + chainBuf: GPUBuffer; + tempOut: GPUBuffer; + countsA: GPUBuffer; + countsB: GPUBuffer; + offsetsA: GPUBuffer; + offsetsB: GPUBuffer; + perLevelChunkPlan: GPUBuffer[]; + perLevelScatterPlan: GPUBuffer[]; + perLevelCarryPlan: GPUBuffer[]; + perLevelTotals: GPUBuffer[]; + metaParams: GPUBuffer; + activeParams: GPUBuffer; + plannerParams: GPUBuffer; + marshalConsts: GPUBuffer; + scatterConsts: GPUBuffer; + carryConsts: GPUBuffer; + v2RunParams: GPUBuffer; + M: number; + maxChunks: number; + perThread: number; +} + +function allocScratch( + device: GPUDevice, + num_columns: number, + input_size: number, + s: number, + max_levels: number, + tpb: number, + per_thread: number, +): Scratch { + const M = input_size + 2; + const maxChunks = Math.max(1, Math.ceil(input_size / 2 / s) + 1); + + const mk = (bytes: number, extra: GPUBufferUsageFlags = 0): GPUBuffer => + device.createBuffer({ size: bytes, usage: GPUBufferUsage.STORAGE | extra }); + + const activeBytes = 2 * PG * M * PG_VEC4_BYTES; + const activeA = mk(activeBytes, GPUBufferUsage.COPY_DST); + const activeB = mk(activeBytes, GPUBufferUsage.COPY_DST); + + const chainBuf = mk(2 * PG * (2 * s * maxChunks) * PG_VEC4_BYTES); + const tempOut = mk(2 * PG * (s * maxChunks) * PG_VEC4_BYTES); + + const countsBytes = num_columns * 4; + const offsetsBytes = num_columns * 4; + const countsA = mk(countsBytes); + const countsB = mk(countsBytes); + const offsetsA = mk(offsetsBytes); + const offsetsB = mk(offsetsBytes); + + const perLevelChunkPlan: GPUBuffer[] = []; + const perLevelScatterPlan: GPUBuffer[] = []; + const perLevelCarryPlan: GPUBuffer[] = []; + const perLevelTotals: GPUBuffer[] = []; + const chunkPlanBytes = 2 * s * maxChunks * 4; + const scatterPlanBytes = s * maxChunks * 4; + const carryPlanBytes = 2 * num_columns * 4; + const totalsBytes = Math.max(40, Math.ceil(40 / 16) * 16); + for (let lvl = 0; lvl < max_levels; lvl++) { + perLevelChunkPlan.push(mk(chunkPlanBytes)); + perLevelScatterPlan.push(mk(scatterPlanBytes)); + perLevelCarryPlan.push(mk(carryPlanBytes)); + perLevelTotals.push(mk(totalsBytes, GPUBufferUsage.INDIRECT | GPUBufferUsage.COPY_DST)); + } + + const ub = (bytes: number): GPUBuffer => + device.createBuffer({ size: Math.max(16, bytes), usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + const metaParams = ub(16); + const activeParams = ub(16); + const plannerParams = ub(16); + const marshalConsts = ub(16); + const scatterConsts = ub(16); + const carryConsts = ub(16); + const v2RunParams = ub(16); + + return { + activeA, activeB, chainBuf, tempOut, + countsA, countsB, offsetsA, offsetsB, + perLevelChunkPlan, perLevelScatterPlan, perLevelCarryPlan, perLevelTotals, + metaParams, activeParams, plannerParams, + marshalConsts, scatterConsts, carryConsts, v2RunParams, + M, maxChunks, perThread: per_thread, + }; +} + +function destroyScratch(scratch: Scratch): void { + scratch.activeA.destroy(); + scratch.activeB.destroy(); + scratch.chainBuf.destroy(); + scratch.tempOut.destroy(); + scratch.countsA.destroy(); + scratch.countsB.destroy(); + scratch.offsetsA.destroy(); + scratch.offsetsB.destroy(); + for (const b of scratch.perLevelChunkPlan) b.destroy(); + for (const b of scratch.perLevelScatterPlan) b.destroy(); + for (const b of scratch.perLevelCarryPlan) b.destroy(); + for (const b of scratch.perLevelTotals) b.destroy(); + scratch.metaParams.destroy(); + scratch.activeParams.destroy(); + scratch.plannerParams.destroy(); + scratch.marshalConsts.destroy(); + scratch.scatterConsts.destroy(); + scratch.carryConsts.destroy(); + scratch.v2RunParams.destroy(); +} + +function buildPadPair(M: number): Uint32Array { + const padPair = new Uint32Array(2 * PG * 2 * 4); + for (let i = 0; i < padPair.length; i++) { + padPair[i] = (0x9e3779b9 * (i + 1)) >>> 0; + } + if (padPair[0] === padPair[PG * 4]) padPair[PG * 4] ^= 1; + return padPair; } /** - * Reference upper bound on the chunk count any level can produce, used - * by the orchestrator to size chunk_plan / scatter_plan / chain_buf / - * tempOut at the worst case (level 0 with all pairs). + * Run the v2 pair-tree MSM bucket-accumulate for ALL pippenger windows + * in a single GPU submit. * - * Per-bucket count C, total active points N (sum of counts), per-level - * pair count is bounded by floor(N / 2). After bin-packing into chunks - * of S, numChunks <= ceil(N / 2 / S). Plus a +num_columns slack for the - * carry-forward elements that bump some buckets at the next level. + * On return the caller's running_x / running_y / bucket_active buffers + * hold each bucket's reduced packed point (or 0/inactive marker) ready + * for batch_affine_finalize_collect to consume. The caller's val_idx / + * row_ptr / cached-bases buffers are read-only. */ +export async function runSmvpV2PairTree(opts: SmvpV2PairTreeOptions): Promise { + const { + device, shaderManager, num_subtasks, num_columns, input_size, + val_idx_buf, row_ptr_buf, point_x_buf, point_y_buf, + running_x_buf, running_y_buf, bucket_active_buf, + } = opts; + const s = opts.s ?? 16; + const tpb = opts.tpb ?? 256; + const per_thread = opts.per_thread ?? Math.max(1, Math.ceil(num_columns / tpb)); + const wgi = opts.wgi ?? 64; + const max_levels = opts.max_levels ?? 8; + if (tpb * per_thread < num_columns) { + throw new Error(`smvp_v2_pair_tree: tpb*per_thread (${tpb}*${per_thread}=${tpb * per_thread}) must be >= num_columns (${num_columns}).`); + } + + const pipelines = await compileAll(device, shaderManager, wgi, s, tpb, per_thread); + const scratch = allocScratch(device, num_columns, input_size, s, max_levels, tpb, per_thread); + const M = scratch.M; + + device.queue.writeBuffer(scratch.metaParams, 0, new Uint32Array([num_columns, num_columns, 0, 0])); + device.queue.writeBuffer(scratch.activeParams, 0, new Uint32Array([input_size, 0, 0, 0])); + device.queue.writeBuffer(scratch.plannerParams, 0, new Uint32Array([num_columns, 0, 0, 0])); + device.queue.writeBuffer(scratch.marshalConsts, 0, new Uint32Array([M, 0, 0, 0])); + device.queue.writeBuffer(scratch.scatterConsts, 0, new Uint32Array([M, 0, 0, 0])); + device.queue.writeBuffer(scratch.carryConsts, 0, new Uint32Array([M, M, 0, 0])); + device.queue.writeBuffer(scratch.v2RunParams, 0, new Uint32Array([num_columns, M, 0, 0])); + + const padPair = buildPadPair(M); + const padOff = PG * (M - 2) * PG_VEC4_BYTES; + device.queue.writeBuffer(scratch.activeA, padOff, padPair as BufferSource); + device.queue.writeBuffer(scratch.activeB, padOff, padPair as BufferSource); + device.queue.writeBuffer(scratch.activeA, PG * M * PG_VEC4_BYTES + padOff, padPair as BufferSource); + device.queue.writeBuffer(scratch.activeB, PG * M * PG_VEC4_BYTES + padOff, padPair as BufferSource); + + const encoder = device.createCommandEncoder(); + let totalPasses = 0; + const directPass = (pipe: GPUComputePipeline, bind: GPUBindGroup, x: number, y = 1, z = 1): void => { + const pass = encoder.beginComputePass(); + pass.setPipeline(pipe); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroups(x, y, z); + pass.end(); + totalPasses++; + }; + const indirectPass = (pipe: GPUComputePipeline, bind: GPUBindGroup, argsBuf: GPUBuffer, byteOffset: number): void => { + const pass = encoder.beginComputePass(); + pass.setPipeline(pipe); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroupsIndirect(argsBuf, byteOffset); + pass.end(); + totalPasses++; + }; + + const valIdxStride = input_size * 4; + const rowPtrStride = (num_columns + 1) * 4; + const runningStride = num_columns * ELEMENT_BYTES; + const bucketActiveStride = num_columns * 4; + + for (let st = 0; st < num_subtasks; st++) { + const valIdxView = { buffer: val_idx_buf, offset: st * valIdxStride, size: valIdxStride } as const; + const rowPtrView = { buffer: row_ptr_buf, offset: st * rowPtrStride, size: rowPtrStride } as const; + + const metaBind = device.createBindGroup({ + layout: pipelines.layouts.meta, + entries: [ + { binding: 0, resource: rowPtrView }, + { binding: 1, resource: { buffer: scratch.countsA } }, + { binding: 2, resource: { buffer: scratch.offsetsA } }, + { binding: 3, resource: { buffer: scratch.metaParams } }, + ], + }); + directPass(pipelines.csrMeta, metaBind, Math.ceil(num_columns / wgi)); + + const activeBind = device.createBindGroup({ + layout: pipelines.layouts.active, + entries: [ + { binding: 0, resource: valIdxView }, + { binding: 1, resource: { buffer: point_x_buf } }, + { binding: 2, resource: { buffer: point_y_buf } }, + { binding: 3, resource: { buffer: scratch.activeA, offset: 0, size: PG * M * PG_VEC4_BYTES } }, + { binding: 4, resource: { buffer: scratch.activeA, offset: PG * M * PG_VEC4_BYTES, size: PG * M * PG_VEC4_BYTES } }, + { binding: 5, resource: { buffer: scratch.activeParams } }, + ], + }); + directPass(pipelines.csrActive, activeBind, Math.ceil(input_size / wgi)); + + let curActive: GPUBuffer = scratch.activeA; + let nextActive: GPUBuffer = scratch.activeB; + let curCounts: GPUBuffer = scratch.countsA; + let curOffsets: GPUBuffer = scratch.offsetsA; + let nextCounts: GPUBuffer = scratch.countsB; + let nextOffsets: GPUBuffer = scratch.offsetsB; + + for (let lvl = 0; lvl < max_levels; lvl++) { + const chunkPlanBuf = scratch.perLevelChunkPlan[lvl]; + const scatterPlanBuf = scratch.perLevelScatterPlan[lvl]; + const carryPlanBuf = scratch.perLevelCarryPlan[lvl]; + const totalsBuf = scratch.perLevelTotals[lvl]; + + const plannerBind = device.createBindGroup({ + layout: pipelines.layouts.planner, + entries: [ + { binding: 0, resource: { buffer: curCounts } }, + { binding: 1, resource: { buffer: curOffsets } }, + { binding: 2, resource: { buffer: chunkPlanBuf } }, + { binding: 3, resource: { buffer: scatterPlanBuf } }, + { binding: 4, resource: { buffer: carryPlanBuf } }, + { binding: 5, resource: { buffer: nextCounts } }, + { binding: 6, resource: { buffer: nextOffsets } }, + { binding: 7, resource: { buffer: totalsBuf } }, + { binding: 8, resource: { buffer: scratch.plannerParams } }, + ], + }); + directPass(pipelines.planner, plannerBind, 1); + + const marshalBind = device.createBindGroup({ + layout: pipelines.layouts.marshal, + entries: [ + { binding: 0, resource: { buffer: chunkPlanBuf } }, + { binding: 1, resource: { buffer: curActive } }, + { binding: 2, resource: { buffer: scratch.chainBuf } }, + { binding: 3, resource: { buffer: totalsBuf } }, + { binding: 4, resource: { buffer: scratch.marshalConsts } }, + ], + }); + indirectPass(pipelines.marshal, marshalBind, totalsBuf, 16); + + const disjointBind = device.createBindGroup({ + layout: pipelines.layouts.disjoint, + entries: [ + { binding: 0, resource: { buffer: scratch.chainBuf } }, + { binding: 1, resource: { buffer: scratch.chainBuf } }, + { binding: 2, resource: { buffer: scratch.tempOut } }, + { binding: 3, resource: { buffer: totalsBuf } }, + ], + }); + indirectPass(pipelines.disjoint, disjointBind, totalsBuf, 16); + + const scatterBind = device.createBindGroup({ + layout: pipelines.layouts.scatter, + entries: [ + { binding: 0, resource: { buffer: scatterPlanBuf } }, + { binding: 1, resource: { buffer: scratch.tempOut } }, + { binding: 2, resource: { buffer: nextActive } }, + { binding: 3, resource: { buffer: totalsBuf } }, + { binding: 4, resource: { buffer: scratch.scatterConsts } }, + ], + }); + indirectPass(pipelines.scatter, scatterBind, totalsBuf, 16); + + const carryBind = device.createBindGroup({ + layout: pipelines.layouts.carry, + entries: [ + { binding: 0, resource: { buffer: carryPlanBuf } }, + { binding: 1, resource: { buffer: curActive } }, + { binding: 2, resource: { buffer: nextActive } }, + { binding: 3, resource: { buffer: totalsBuf } }, + { binding: 4, resource: { buffer: scratch.carryConsts } }, + ], + }); + indirectPass(pipelines.carry, carryBind, totalsBuf, 28); + + [curActive, nextActive] = [nextActive, curActive]; + [curCounts, nextCounts] = [nextCounts, curCounts]; + [curOffsets, nextOffsets] = [nextOffsets, curOffsets]; + } + + const subtaskBucketOff = st * num_columns; + const v2RunBind = device.createBindGroup({ + layout: pipelines.layouts.v2Run, + entries: [ + { binding: 0, resource: { buffer: curActive } }, + { binding: 1, resource: { buffer: curCounts } }, + { binding: 2, resource: { buffer: curOffsets } }, + { binding: 3, resource: { buffer: running_x_buf, offset: subtaskBucketOff * ELEMENT_BYTES, size: runningStride } }, + { binding: 4, resource: { buffer: running_y_buf, offset: subtaskBucketOff * ELEMENT_BYTES, size: runningStride } }, + { binding: 5, resource: { buffer: bucket_active_buf, offset: subtaskBucketOff * 4, size: bucketActiveStride } }, + { binding: 6, resource: { buffer: scratch.v2RunParams } }, + ], + }); + directPass(pipelines.v2ToRunning, v2RunBind, Math.ceil(num_columns / wgi)); + } + + const t0 = performance.now(); + device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + const gpu_wall_ms = performance.now() - t0; + + destroyScratch(scratch); + + return { + levels_per_window: max_levels, + num_subtasks, + num_columns, + total_passes: totalPasses, + gpu_wall_ms, + }; +} + export function maxChunksUpperBound(input_size: number, num_columns: number, s: number): number { return Math.max(1, Math.ceil(input_size / 2 / s) + num_columns); } -/** - * Buffer-byte-size helpers — kept here so the production msm.ts - * integration can pre-allocate matching scratch when wiring v2 in - * behind a flag. - */ export const sizes = { - /** Combined-SoA active_sums byte size for one window, including the pad pair. */ activeSumsBytes(input_size: number): number { const M = input_size + 2; return 2 * PG * M * 16; }, - /** chain_buf byte size for one window. */ chainBufBytes(input_size: number, num_columns: number, s: number): number { const T = maxChunksUpperBound(input_size, num_columns, s); return 2 * PG * (2 * s * T) * 16; }, - /** tempOut byte size for one window. */ tempOutBytes(input_size: number, num_columns: number, s: number): number { const T = maxChunksUpperBound(input_size, num_columns, s); return 2 * PG * (s * T) * 16; }, - /** chunk_plan byte size per level. */ chunkPlanBytes(input_size: number, num_columns: number, s: number): number { const T = maxChunksUpperBound(input_size, num_columns, s); return 2 * s * T * 4; }, - /** scatter_plan byte size per level. */ scatterPlanBytes(input_size: number, num_columns: number, s: number): number { const T = maxChunksUpperBound(input_size, num_columns, s); return s * T * 4; }, - /** carry_plan byte size per level. */ carryPlanBytes(num_columns: number): number { return 2 * num_columns * 4; }, - /** counts byte size per level. */ countsBytes(num_columns: number): number { return num_columns * 4; }, - /** offsets byte size per level. */ offsetsBytes(num_columns: number): number { return num_columns * 4; }, diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index f9869415faaf..36aa506f4b74 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -1,6 +1,6 @@ // AUTO-GENERATED by scripts/inline-wgsl.mjs. DO NOT EDIT. // Run `yarn generate:wgsl` (or `node scripts/inline-wgsl.mjs`) to regenerate. -// 65 shader sources inlined. +// 70 shader sources inlined. /* eslint-disable */ @@ -1407,6 +1407,52 @@ fn main(@builtin(global_invocation_id) gid: vec3) { } `; +export const ba_carry_copy_prod = `{{> structs }} + +// Carry-copy kernel — prod variant for the v2 pair-tree integration. +// num_carries is read from the planner's totals[1] and dispatch is +// indirect via totals[7..9]. + +const PG: u32 = 2u; + +@group(0) @binding(0) var carry_plan: array; +@group(0) @binding(1) var active_sums_old: array>; +@group(0) @binding(2) var active_sums_new: array>; +@group(0) @binding(3) var totals: array; +@group(0) @binding(4) var consts: vec4; +// consts.x = M_old +// consts.y = M_new + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = totals[1]; + let M_old = consts.x; + let M_new = consts.y; + let t = gid.x; + if (t >= T) { return; } + + let src_idx = carry_plan[2u * t + 0u]; + let dst_idx = carry_plan[2u * t + 1u]; + + let old_plane_x = 0u * PG * M_old; + let old_plane_y = 1u * PG * M_old; + let new_plane_x = 0u * PG * M_new; + let new_plane_y = 1u * PG * M_new; + + let src_x = old_plane_x + PG * src_idx; + let src_y = old_plane_y + PG * src_idx; + let dst_x = new_plane_x + PG * dst_idx; + let dst_y = new_plane_y + PG * dst_idx; + + active_sums_new[dst_x + 0u] = active_sums_old[src_x + 0u]; + active_sums_new[dst_x + 1u] = active_sums_old[src_x + 1u]; + active_sums_new[dst_y + 0u] = active_sums_old[src_y + 0u]; + active_sums_new[dst_y + 1u] = active_sums_old[src_y + 1u]; + + {{{ recompile }}} +} +`; + export const ba_fused_super_bench = `{{> structs }} {{> bigint_funcs }} {{> montgomery_product_funcs }} @@ -1740,6 +1786,73 @@ fn main(@builtin(global_invocation_id) gid: vec3) { } `; +export const ba_marshal_pairs_prod = `{{> structs }} + +// Marshal kernel — prod variant for the v2 pair-tree integration. +// +// Same indexing math as ba_marshal_pairs_bench. The only structural +// change: the per-level T (= num_chunks) is read from the planner's +// totals[3] storage output instead of a host-set uniform, and the +// host dispatches via dispatchWorkgroupsIndirect(totals, 16). This +// dispatches exactly ceil(num_chunks / WG) workgroups so no pad +// chunks are computed. + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var chunk_plan: array; +@group(0) @binding(1) var active_sums: array>; +@group(0) @binding(2) var chain_buf: array>; +@group(0) @binding(3) var totals: array; +@group(0) @binding(4) var consts: vec4; +// consts.x = M_in + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = totals[3]; + let M_in = consts.x; + let t = gid.x; + if (t >= T) { return; } + + let chain_N = 2u * S * T; + let chain_plane_x = 0u * PG * chain_N; + let chain_plane_y = 1u * PG * chain_N; + + let active_plane_x = 0u * PG * M_in; + let active_plane_y = 1u * PG * M_in; + + let chunk_base = 2u * S * t; + for (var k: u32 = 0u; k < S; k = k + 1u) { + let idx_l = chunk_plan[chunk_base + 2u * k + 0u]; + let idx_r = chunk_plan[chunk_base + 2u * k + 1u]; + + let e_l = t + (2u * k + 0u) * T; + let e_r = t + (2u * k + 1u) * T; + + let src_lx = active_plane_x + PG * idx_l; + let src_ly = active_plane_y + PG * idx_l; + let src_rx = active_plane_x + PG * idx_r; + let src_ry = active_plane_y + PG * idx_r; + + let dst_lx = chain_plane_x + PG * e_l; + let dst_ly = chain_plane_y + PG * e_l; + let dst_rx = chain_plane_x + PG * e_r; + let dst_ry = chain_plane_y + PG * e_r; + + chain_buf[dst_lx + 0u] = active_sums[src_lx + 0u]; + chain_buf[dst_lx + 1u] = active_sums[src_lx + 1u]; + chain_buf[dst_ly + 0u] = active_sums[src_ly + 0u]; + chain_buf[dst_ly + 1u] = active_sums[src_ly + 1u]; + chain_buf[dst_rx + 0u] = active_sums[src_rx + 0u]; + chain_buf[dst_rx + 1u] = active_sums[src_rx + 1u]; + chain_buf[dst_ry + 0u] = active_sums[src_ry + 0u]; + chain_buf[dst_ry + 1u] = active_sums[src_ry + 1u]; + } + + {{{ recompile }}} +} +`; + export const ba_marshal_tree_l0_bench = `{{> structs }} // Marshal kernel for the bench-msm-tree pair-tree pipeline: transposes @@ -2117,6 +2230,127 @@ fn main(@builtin(global_invocation_id) gid: vec3) { } `; +export const ba_pair_disjoint_tree_prod = `{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +// Disjoint pair-sum kernel — prod variant for the v2 pair-tree +// integration. Same disjoint pair-sum math as +// ba_pair_disjoint_tree_bench (suffix-product single fr_inv_by_a per +// chunk + lean affine add); the per-level T (= num_chunks) is read +// from the planner's totals[3] storage output and the dispatch happens +// indirectly so only real chunks run. Always uses the final-mode +// strided write (matches what ba_scatter_pairs_prod expects). +// +// LAYOUT: same as the bench variant. Combined-SoA input/output (2 +// planes, PG=2 vec4 per element, plane-major then element-major then +// vec4 within an element). + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var inp: array>; +@group(0) @binding(1) var unused: array>; +@group(0) @binding(2) var outp: array>; +@group(0) @binding(3) var totals: array; + +fn load_in(plane: u32, t: u32, i: u32, T: u32, N_in: u32) -> BigInt { + let plane_base = plane * PG * N_in; + let base = plane_base + PG * (t + i * T); + let q0 = inp[base + 0u]; + let q1 = inp[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_out_simple(plane: u32, t: u32, k: u32, T_curr: u32, N_out: u32, val: ptr) { + let plane_base = plane * PG * N_out; + let elem = t + k * T_curr; + let base = plane_base + PG * elem; + let w = pack_limbs_to_256(val); + outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T_curr = totals[3]; + let N_in = 2u * S * T_curr; + let N_out = S * T_curr; + + let t = gid.x; + if (t >= T_curr) { return; } + + var pref: array; + var acc: BigInt = get_r(); + for (var k: u32 = 0u; k < S; k = k + 1u) { + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T_curr, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T_curr, N_in); + var dx: BigInt = fr_sub(&p_rx, &p_lx); + if (k == 0u) { + acc = dx; + } else { + acc = montgomery_product(&acc, &dx); + } + pref[k] = acc; + } + + var inv: BigInt = fr_inv_by_a(acc); + + for (var jj: u32 = 0u; jj < S; jj = jj + 1u) { + let k = S - 1u - jj; + + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T_curr, N_in); + var p_ly: BigInt = load_in(1u, t, 2u * k + 0u, T_curr, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T_curr, N_in); + var p_ry: BigInt = load_in(1u, t, 2u * k + 1u, T_curr, N_in); + + var inv_dx: BigInt; + if (k == 0u) { + inv_dx = inv; + } else { + var pp = pref[k - 1u]; + inv_dx = montgomery_product(&inv, &pp); + } + + var lambda: BigInt = fr_sub(&p_ry, &p_ly); + lambda = montgomery_product(&lambda, &inv_dx); + var r_x: BigInt = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &p_lx); + r_x = fr_sub(&r_x, &p_rx); + var r_y: BigInt = fr_sub(&p_lx, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &p_ly); + + store_out_simple(0u, t, k, T_curr, N_out, &r_x); + store_out_simple(1u, t, k, T_curr, N_out, &r_y); + + if (k > 0u) { + var dx_back: BigInt = fr_sub(&p_rx, &p_lx); + inv = montgomery_product(&inv, &dx_back); + } + } + + {{{ recompile }}} +} +`; + export const ba_planner_bench = `{{> structs }} // GPU-side bin-packing planner for the v3 MSM bucket-accumulate @@ -2384,6 +2618,178 @@ fn main(@builtin(local_invocation_id) lid: vec3) { } `; +export const ba_planner_v2_prod = `{{> structs }} + +// Production GPU bin-packing planner for the v2 pair-tree integration. +// +// Same algorithm as ba_planner_v2_bench (one workgroup of TPB threads, +// per-thread local scan, workgroup-wide Hillis-Steele scan over the +// three running sums, per-thread scatter) but extends the totals +// output with the indirect-dispatch counts the production marshal / +// disjoint / scatter / carry kernels need: +// +// totals[0] = total_pairs +// totals[1] = total_carries +// totals[2] = total_new +// totals[3] = num_chunks = max(1, (total_pairs + S - 1) / S) +// totals[4] = marshal/disjoint/scatter dispatch X (= ceil(num_chunks / WGI)) +// totals[5] = 1 +// totals[6] = 1 +// totals[7] = carry dispatch X (= ceil(total_carries / WGI)) +// totals[8] = 1 +// totals[9] = 1 +// +// The four prod-variant downstream kernels (ba_marshal_pairs_prod, +// ba_pair_disjoint_tree_prod, ba_scatter_pairs_prod, ba_carry_copy_prod) +// read num_chunks and total_carries from this same totals storage +// buffer so a single planner dispatch fully drives the level's runtime +// shape with zero wasted-pad-chunk compute. The host orchestrator +// reuses the totals buffer as the indirect-dispatch source via +// dispatchWorkgroupsIndirect(totals, 16) for marshal/disjoint/scatter +// (totals u32 indices 4..6) and dispatchWorkgroupsIndirect(totals, 28) +// for carry (totals u32 indices 7..9). +// +// Compile-time constants: +// TPB : workgroup size (e.g. 256) +// PER_THREAD : buckets per thread +// PAIR_CAP : per-bucket pair-count bound +// S : chunk size in pairs +// WGI : downstream kernel workgroup size — must match the +// workgroup_size of ba_marshal_pairs_prod / +// ba_pair_disjoint_tree_prod / ba_scatter_pairs_prod / +// ba_carry_copy_prod. + +const TPB: u32 = {{ workgroup_size }}u; +const PER_THREAD: u32 = {{ per_thread }}u; +const PAIR_CAP: u32 = {{ pair_cap }}u; +const S: u32 = {{ s }}u; +const WGI: u32 = {{ wgi }}u; + +@group(0) @binding(0) var counts: array; +@group(0) @binding(1) var offsets: array; +@group(0) @binding(2) var chunk_plan: array; +@group(0) @binding(3) var scatter_plan: array; +@group(0) @binding(4) var carry_plan: array; +@group(0) @binding(5) var new_counts: array; +@group(0) @binding(6) var new_offsets: array; +@group(0) @binding(7) var totals: array; +@group(0) @binding(8) var params: vec4; +// params.x = B + +var pair_scan: array; +var carry_scan: array; +var new_scan: array; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; + let B = params.x; + + var local_pc: array; + var local_cf: array; + var local_nc: array; + var sum_p: u32 = 0u; + var sum_c: u32 = 0u; + var sum_n: u32 = 0u; + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tid * PER_THREAD + k; + var pc: u32 = 0u; + var cf: u32 = 0u; + var nc: u32 = 0u; + if (b < B) { + let n = counts[b]; + pc = n / 2u; + cf = n & 1u; + nc = pc + cf; + } + local_pc[k] = pc; + local_cf[k] = cf; + local_nc[k] = nc; + sum_p += pc; + sum_c += cf; + sum_n += nc; + } + + pair_scan[tid] = sum_p; + carry_scan[tid] = sum_c; + new_scan[tid] = sum_n; + workgroupBarrier(); + for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) { + var add_p: u32 = 0u; + var add_c: u32 = 0u; + var add_n: u32 = 0u; + if (tid >= stride) { + add_p = pair_scan[tid - stride]; + add_c = carry_scan[tid - stride]; + add_n = new_scan[tid - stride]; + } + workgroupBarrier(); + if (tid >= stride) { + pair_scan[tid] = pair_scan[tid] + add_p; + carry_scan[tid] = carry_scan[tid] + add_c; + new_scan[tid] = new_scan[tid] + add_n; + } + workgroupBarrier(); + } + var local_pair_off: u32 = pair_scan[tid] - sum_p; + var local_carry_off: u32 = carry_scan[tid] - sum_c; + var local_new_off: u32 = new_scan[tid] - sum_n; + + if (tid == TPB - 1u) { + let tp = pair_scan[tid]; + let tc = carry_scan[tid]; + let tn = new_scan[tid]; + totals[0] = tp; + totals[1] = tc; + totals[2] = tn; + let num_chunks = max(1u, (tp + S - 1u) / S); + totals[3] = num_chunks; + totals[4] = max(1u, (num_chunks + WGI - 1u) / WGI); + totals[5] = 1u; + totals[6] = 1u; + totals[7] = max(1u, (tc + WGI - 1u) / WGI); + totals[8] = 1u; + totals[9] = 1u; + } + + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tid * PER_THREAD + k; + if (b >= B) { break; } + + let pc = local_pc[k]; + let cf = local_cf[k]; + let nc = local_nc[k]; + new_counts[b] = nc; + new_offsets[b] = local_new_off; + + let bucket_base = offsets[b]; + + for (var j: u32 = 0u; j < PAIR_CAP; j = j + 1u) { + if (j >= pc) { break; } + let global_slot = local_pair_off + j; + let chunk_id = global_slot / S; + let slot_in_chunk = global_slot % S; + let cp_base = 2u * (chunk_id * S + slot_in_chunk); + chunk_plan[cp_base + 0u] = bucket_base + 2u * j; + chunk_plan[cp_base + 1u] = bucket_base + 2u * j + 1u; + scatter_plan[chunk_id * S + slot_in_chunk] = local_new_off + j; + } + + if (cf != 0u) { + let cs = local_carry_off; + carry_plan[2u * cs + 0u] = bucket_base + counts[b] - 1u; + carry_plan[2u * cs + 1u] = local_new_off + pc; + } + + local_pair_off += pc; + local_carry_off += cf; + local_new_off += nc; + } + + {{{ recompile }}} +} +`; + export const ba_rev_packed_carry_bench = `{{> structs }} {{> bigint_funcs }} {{> montgomery_product_funcs }} @@ -2621,6 +3027,56 @@ fn main(@builtin(global_invocation_id) gid: vec3) { } `; +export const ba_scatter_pairs_prod = `{{> structs }} + +// Scatter kernel — prod variant for the v2 pair-tree integration. +// Same per-bucket placement math as ba_scatter_pairs_bench; T is read +// from the planner's totals[3] and the dispatch is indirect via +// totals[4..6]. + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var scatter_plan: array; +@group(0) @binding(1) var disjoint_out: array>; +@group(0) @binding(2) var active_sums_new: array>; +@group(0) @binding(3) var totals: array; +@group(0) @binding(4) var consts: vec4; +// consts.x = M_new + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = totals[3]; + let M_new = consts.x; + let t = gid.x; + if (t >= T) { return; } + + let out_N = S * T; + let out_plane_x = 0u * PG * out_N; + let out_plane_y = 1u * PG * out_N; + + let new_plane_x = 0u * PG * M_new; + let new_plane_y = 1u * PG * M_new; + + for (var k: u32 = 0u; k < S; k = k + 1u) { + let e = t + k * T; + let dst_idx = scatter_plan[t * S + k]; + + let src_x = out_plane_x + PG * e; + let src_y = out_plane_y + PG * e; + let dst_x = new_plane_x + PG * dst_idx; + let dst_y = new_plane_y + PG * dst_idx; + + active_sums_new[dst_x + 0u] = disjoint_out[src_x + 0u]; + active_sums_new[dst_x + 1u] = disjoint_out[src_x + 1u]; + active_sums_new[dst_y + 0u] = disjoint_out[src_y + 0u]; + active_sums_new[dst_y + 1u] = disjoint_out[src_y + 1u]; + } + + {{{ recompile }}} +} +`; + export const ba_tail_reduce_bench = `{{> structs }} {{> bigint_funcs }} {{> montgomery_product_funcs }} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_carry_copy_prod.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_carry_copy_prod.template.wgsl new file mode 100644 index 000000000000..c3b1b12787ed --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_carry_copy_prod.template.wgsl @@ -0,0 +1,44 @@ +{{> structs }} + +// Carry-copy kernel — prod variant for the v2 pair-tree integration. +// num_carries is read from the planner's totals[1] and dispatch is +// indirect via totals[7..9]. + +const PG: u32 = 2u; + +@group(0) @binding(0) var carry_plan: array; +@group(0) @binding(1) var active_sums_old: array>; +@group(0) @binding(2) var active_sums_new: array>; +@group(0) @binding(3) var totals: array; +@group(0) @binding(4) var consts: vec4; +// consts.x = M_old +// consts.y = M_new + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = totals[1]; + let M_old = consts.x; + let M_new = consts.y; + let t = gid.x; + if (t >= T) { return; } + + let src_idx = carry_plan[2u * t + 0u]; + let dst_idx = carry_plan[2u * t + 1u]; + + let old_plane_x = 0u * PG * M_old; + let old_plane_y = 1u * PG * M_old; + let new_plane_x = 0u * PG * M_new; + let new_plane_y = 1u * PG * M_new; + + let src_x = old_plane_x + PG * src_idx; + let src_y = old_plane_y + PG * src_idx; + let dst_x = new_plane_x + PG * dst_idx; + let dst_y = new_plane_y + PG * dst_idx; + + active_sums_new[dst_x + 0u] = active_sums_old[src_x + 0u]; + active_sums_new[dst_x + 1u] = active_sums_old[src_x + 1u]; + active_sums_new[dst_y + 0u] = active_sums_old[src_y + 0u]; + active_sums_new[dst_y + 1u] = active_sums_old[src_y + 1u]; + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_pairs_prod.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_pairs_prod.template.wgsl new file mode 100644 index 000000000000..3cff285af1ca --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_pairs_prod.template.wgsl @@ -0,0 +1,65 @@ +{{> structs }} + +// Marshal kernel — prod variant for the v2 pair-tree integration. +// +// Same indexing math as ba_marshal_pairs_bench. The only structural +// change: the per-level T (= num_chunks) is read from the planner's +// totals[3] storage output instead of a host-set uniform, and the +// host dispatches via dispatchWorkgroupsIndirect(totals, 16). This +// dispatches exactly ceil(num_chunks / WG) workgroups so no pad +// chunks are computed. + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var chunk_plan: array; +@group(0) @binding(1) var active_sums: array>; +@group(0) @binding(2) var chain_buf: array>; +@group(0) @binding(3) var totals: array; +@group(0) @binding(4) var consts: vec4; +// consts.x = M_in + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = totals[3]; + let M_in = consts.x; + let t = gid.x; + if (t >= T) { return; } + + let chain_N = 2u * S * T; + let chain_plane_x = 0u * PG * chain_N; + let chain_plane_y = 1u * PG * chain_N; + + let active_plane_x = 0u * PG * M_in; + let active_plane_y = 1u * PG * M_in; + + let chunk_base = 2u * S * t; + for (var k: u32 = 0u; k < S; k = k + 1u) { + let idx_l = chunk_plan[chunk_base + 2u * k + 0u]; + let idx_r = chunk_plan[chunk_base + 2u * k + 1u]; + + let e_l = t + (2u * k + 0u) * T; + let e_r = t + (2u * k + 1u) * T; + + let src_lx = active_plane_x + PG * idx_l; + let src_ly = active_plane_y + PG * idx_l; + let src_rx = active_plane_x + PG * idx_r; + let src_ry = active_plane_y + PG * idx_r; + + let dst_lx = chain_plane_x + PG * e_l; + let dst_ly = chain_plane_y + PG * e_l; + let dst_rx = chain_plane_x + PG * e_r; + let dst_ry = chain_plane_y + PG * e_r; + + chain_buf[dst_lx + 0u] = active_sums[src_lx + 0u]; + chain_buf[dst_lx + 1u] = active_sums[src_lx + 1u]; + chain_buf[dst_ly + 0u] = active_sums[src_ly + 0u]; + chain_buf[dst_ly + 1u] = active_sums[src_ly + 1u]; + chain_buf[dst_rx + 0u] = active_sums[src_rx + 0u]; + chain_buf[dst_rx + 1u] = active_sums[src_rx + 1u]; + chain_buf[dst_ry + 0u] = active_sums[src_ry + 0u]; + chain_buf[dst_ry + 1u] = active_sums[src_ry + 1u]; + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_pair_disjoint_tree_prod.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_pair_disjoint_tree_prod.template.wgsl new file mode 100644 index 000000000000..3c7b7b3d4504 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_pair_disjoint_tree_prod.template.wgsl @@ -0,0 +1,119 @@ +{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +// Disjoint pair-sum kernel — prod variant for the v2 pair-tree +// integration. Same disjoint pair-sum math as +// ba_pair_disjoint_tree_bench (suffix-product single fr_inv_by_a per +// chunk + lean affine add); the per-level T (= num_chunks) is read +// from the planner's totals[3] storage output and the dispatch happens +// indirectly so only real chunks run. Always uses the final-mode +// strided write (matches what ba_scatter_pairs_prod expects). +// +// LAYOUT: same as the bench variant. Combined-SoA input/output (2 +// planes, PG=2 vec4 per element, plane-major then element-major then +// vec4 within an element). + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var inp: array>; +@group(0) @binding(1) var unused: array>; +@group(0) @binding(2) var outp: array>; +@group(0) @binding(3) var totals: array; + +fn load_in(plane: u32, t: u32, i: u32, T: u32, N_in: u32) -> BigInt { + let plane_base = plane * PG * N_in; + let base = plane_base + PG * (t + i * T); + let q0 = inp[base + 0u]; + let q1 = inp[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_out_simple(plane: u32, t: u32, k: u32, T_curr: u32, N_out: u32, val: ptr) { + let plane_base = plane * PG * N_out; + let elem = t + k * T_curr; + let base = plane_base + PG * elem; + let w = pack_limbs_to_256(val); + outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T_curr = totals[3]; + let N_in = 2u * S * T_curr; + let N_out = S * T_curr; + + let t = gid.x; + if (t >= T_curr) { return; } + + var pref: array; + var acc: BigInt = get_r(); + for (var k: u32 = 0u; k < S; k = k + 1u) { + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T_curr, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T_curr, N_in); + var dx: BigInt = fr_sub(&p_rx, &p_lx); + if (k == 0u) { + acc = dx; + } else { + acc = montgomery_product(&acc, &dx); + } + pref[k] = acc; + } + + var inv: BigInt = fr_inv_by_a(acc); + + for (var jj: u32 = 0u; jj < S; jj = jj + 1u) { + let k = S - 1u - jj; + + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T_curr, N_in); + var p_ly: BigInt = load_in(1u, t, 2u * k + 0u, T_curr, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T_curr, N_in); + var p_ry: BigInt = load_in(1u, t, 2u * k + 1u, T_curr, N_in); + + var inv_dx: BigInt; + if (k == 0u) { + inv_dx = inv; + } else { + var pp = pref[k - 1u]; + inv_dx = montgomery_product(&inv, &pp); + } + + var lambda: BigInt = fr_sub(&p_ry, &p_ly); + lambda = montgomery_product(&lambda, &inv_dx); + var r_x: BigInt = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &p_lx); + r_x = fr_sub(&r_x, &p_rx); + var r_y: BigInt = fr_sub(&p_lx, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &p_ly); + + store_out_simple(0u, t, k, T_curr, N_out, &r_x); + store_out_simple(1u, t, k, T_curr, N_out, &r_y); + + if (k > 0u) { + var dx_back: BigInt = fr_sub(&p_rx, &p_lx); + inv = montgomery_product(&inv, &dx_back); + } + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_prod.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_prod.template.wgsl new file mode 100644 index 000000000000..f3ec2f414e8e --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_prod.template.wgsl @@ -0,0 +1,170 @@ +{{> structs }} + +// Production GPU bin-packing planner for the v2 pair-tree integration. +// +// Same algorithm as ba_planner_v2_bench (one workgroup of TPB threads, +// per-thread local scan, workgroup-wide Hillis-Steele scan over the +// three running sums, per-thread scatter) but extends the totals +// output with the indirect-dispatch counts the production marshal / +// disjoint / scatter / carry kernels need: +// +// totals[0] = total_pairs +// totals[1] = total_carries +// totals[2] = total_new +// totals[3] = num_chunks = max(1, (total_pairs + S - 1) / S) +// totals[4] = marshal/disjoint/scatter dispatch X (= ceil(num_chunks / WGI)) +// totals[5] = 1 +// totals[6] = 1 +// totals[7] = carry dispatch X (= ceil(total_carries / WGI)) +// totals[8] = 1 +// totals[9] = 1 +// +// The four prod-variant downstream kernels (ba_marshal_pairs_prod, +// ba_pair_disjoint_tree_prod, ba_scatter_pairs_prod, ba_carry_copy_prod) +// read num_chunks and total_carries from this same totals storage +// buffer so a single planner dispatch fully drives the level's runtime +// shape with zero wasted-pad-chunk compute. The host orchestrator +// reuses the totals buffer as the indirect-dispatch source via +// dispatchWorkgroupsIndirect(totals, 16) for marshal/disjoint/scatter +// (totals u32 indices 4..6) and dispatchWorkgroupsIndirect(totals, 28) +// for carry (totals u32 indices 7..9). +// +// Compile-time constants: +// TPB : workgroup size (e.g. 256) +// PER_THREAD : buckets per thread +// PAIR_CAP : per-bucket pair-count bound +// S : chunk size in pairs +// WGI : downstream kernel workgroup size — must match the +// workgroup_size of ba_marshal_pairs_prod / +// ba_pair_disjoint_tree_prod / ba_scatter_pairs_prod / +// ba_carry_copy_prod. + +const TPB: u32 = {{ workgroup_size }}u; +const PER_THREAD: u32 = {{ per_thread }}u; +const PAIR_CAP: u32 = {{ pair_cap }}u; +const S: u32 = {{ s }}u; +const WGI: u32 = {{ wgi }}u; + +@group(0) @binding(0) var counts: array; +@group(0) @binding(1) var offsets: array; +@group(0) @binding(2) var chunk_plan: array; +@group(0) @binding(3) var scatter_plan: array; +@group(0) @binding(4) var carry_plan: array; +@group(0) @binding(5) var new_counts: array; +@group(0) @binding(6) var new_offsets: array; +@group(0) @binding(7) var totals: array; +@group(0) @binding(8) var params: vec4; +// params.x = B + +var pair_scan: array; +var carry_scan: array; +var new_scan: array; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; + let B = params.x; + + var local_pc: array; + var local_cf: array; + var local_nc: array; + var sum_p: u32 = 0u; + var sum_c: u32 = 0u; + var sum_n: u32 = 0u; + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tid * PER_THREAD + k; + var pc: u32 = 0u; + var cf: u32 = 0u; + var nc: u32 = 0u; + if (b < B) { + let n = counts[b]; + pc = n / 2u; + cf = n & 1u; + nc = pc + cf; + } + local_pc[k] = pc; + local_cf[k] = cf; + local_nc[k] = nc; + sum_p += pc; + sum_c += cf; + sum_n += nc; + } + + pair_scan[tid] = sum_p; + carry_scan[tid] = sum_c; + new_scan[tid] = sum_n; + workgroupBarrier(); + for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) { + var add_p: u32 = 0u; + var add_c: u32 = 0u; + var add_n: u32 = 0u; + if (tid >= stride) { + add_p = pair_scan[tid - stride]; + add_c = carry_scan[tid - stride]; + add_n = new_scan[tid - stride]; + } + workgroupBarrier(); + if (tid >= stride) { + pair_scan[tid] = pair_scan[tid] + add_p; + carry_scan[tid] = carry_scan[tid] + add_c; + new_scan[tid] = new_scan[tid] + add_n; + } + workgroupBarrier(); + } + var local_pair_off: u32 = pair_scan[tid] - sum_p; + var local_carry_off: u32 = carry_scan[tid] - sum_c; + var local_new_off: u32 = new_scan[tid] - sum_n; + + if (tid == TPB - 1u) { + let tp = pair_scan[tid]; + let tc = carry_scan[tid]; + let tn = new_scan[tid]; + totals[0] = tp; + totals[1] = tc; + totals[2] = tn; + let num_chunks = max(1u, (tp + S - 1u) / S); + totals[3] = num_chunks; + totals[4] = max(1u, (num_chunks + WGI - 1u) / WGI); + totals[5] = 1u; + totals[6] = 1u; + totals[7] = max(1u, (tc + WGI - 1u) / WGI); + totals[8] = 1u; + totals[9] = 1u; + } + + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tid * PER_THREAD + k; + if (b >= B) { break; } + + let pc = local_pc[k]; + let cf = local_cf[k]; + let nc = local_nc[k]; + new_counts[b] = nc; + new_offsets[b] = local_new_off; + + let bucket_base = offsets[b]; + + for (var j: u32 = 0u; j < PAIR_CAP; j = j + 1u) { + if (j >= pc) { break; } + let global_slot = local_pair_off + j; + let chunk_id = global_slot / S; + let slot_in_chunk = global_slot % S; + let cp_base = 2u * (chunk_id * S + slot_in_chunk); + chunk_plan[cp_base + 0u] = bucket_base + 2u * j; + chunk_plan[cp_base + 1u] = bucket_base + 2u * j + 1u; + scatter_plan[chunk_id * S + slot_in_chunk] = local_new_off + j; + } + + if (cf != 0u) { + let cs = local_carry_off; + carry_plan[2u * cs + 0u] = bucket_base + counts[b] - 1u; + carry_plan[2u * cs + 1u] = local_new_off + pc; + } + + local_pair_off += pc; + local_carry_off += cf; + local_new_off += nc; + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_scatter_pairs_prod.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_scatter_pairs_prod.template.wgsl new file mode 100644 index 000000000000..4a14e8539736 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_scatter_pairs_prod.template.wgsl @@ -0,0 +1,48 @@ +{{> structs }} + +// Scatter kernel — prod variant for the v2 pair-tree integration. +// Same per-bucket placement math as ba_scatter_pairs_bench; T is read +// from the planner's totals[3] and the dispatch is indirect via +// totals[4..6]. + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var scatter_plan: array; +@group(0) @binding(1) var disjoint_out: array>; +@group(0) @binding(2) var active_sums_new: array>; +@group(0) @binding(3) var totals: array; +@group(0) @binding(4) var consts: vec4; +// consts.x = M_new + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = totals[3]; + let M_new = consts.x; + let t = gid.x; + if (t >= T) { return; } + + let out_N = S * T; + let out_plane_x = 0u * PG * out_N; + let out_plane_y = 1u * PG * out_N; + + let new_plane_x = 0u * PG * M_new; + let new_plane_y = 1u * PG * M_new; + + for (var k: u32 = 0u; k < S; k = k + 1u) { + let e = t + k * T; + let dst_idx = scatter_plan[t * S + k]; + + let src_x = out_plane_x + PG * e; + let src_y = out_plane_y + PG * e; + let dst_x = new_plane_x + PG * dst_idx; + let dst_y = new_plane_y + PG * dst_idx; + + active_sums_new[dst_x + 0u] = disjoint_out[src_x + 0u]; + active_sums_new[dst_x + 1u] = disjoint_out[src_x + 1u]; + active_sums_new[dst_y + 0u] = disjoint_out[src_y + 0u]; + active_sums_new[dst_y + 1u] = disjoint_out[src_y + 1u]; + } + + {{{ recompile }}} +} From 664eb1ca7582f6498761d40284f5db8f98ab5a95 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Wed, 20 May 2026 14:46:06 +0000 Subject: [PATCH 26/33] test(bb/msm): noble oracle for v2 prod orchestrator + drop pad-pair init Adds dev/msm-webgpu/bench-msm-oracle-prod.{ts,html}: a focused correctness harness that exercises runSmvpV2PairTree end-to-end on real BN254 points and cross-checks each non-empty bucket's reduced sum against @noble/curves projective addition. Defaults to one window, B=32 buckets, N=256 points (TPB=64 PER=1 for the planner so TPB*PER_THREAD>=B); query params allow scaling up. Also drops the pad-pair init that I'd carried over from the bench- msm-tree-v2 host-planner convention. With the prod planner + indirect dispatch the marshal / disjoint / scatter / carry kernels only run for thread IDs in [0, num_chunks_or_num_carries), so the chunk_plan / scatter_plan / carry_plan tail entries (which would have pointed at the pad slots in the host-planner world) are never read. The pad slots in active_sums never get consulted; initialising them is dead work. --- .../dev/msm-webgpu/bench-msm-oracle-prod.html | 22 ++ .../dev/msm-webgpu/bench-msm-oracle-prod.ts | 349 ++++++++++++++++++ .../msm-webgpu/scripts/run-browserstack.mjs | 1 + .../src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts | 20 +- 4 files changed, 377 insertions(+), 15 deletions(-) create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-msm-oracle-prod.html create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-msm-oracle-prod.ts diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle-prod.html b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle-prod.html new file mode 100644 index 000000000000..5126b5f73cc0 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle-prod.html @@ -0,0 +1,22 @@ + + + + + v2 prod orchestrator + indirect dispatch noble oracle (WebGPU) + + + +

v2 prod orchestrator + indirect dispatch noble oracle

+

Query params: ?subtasks=T&columns=B&input=N&s=S&wgi=W&tpb=TPB&per=PER&lvls=L&seed=K

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle-prod.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle-prod.ts new file mode 100644 index 000000000000..10e2be473956 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle-prod.ts @@ -0,0 +1,349 @@ +/// +// End-to-end correctness oracle for the prod v2 pair-tree orchestrator. +// +// Wraps runSmvpV2PairTree (cuzk/smvp_v2_pair_tree.ts) with a noble-CPU +// cross-check on real BN254 affine points. Validates the full prod +// path: csr_to_v2_meta + csr_to_v2_active_sums + planner_v2_prod + +// marshal_prod + disjoint_prod + scatter_prod + carry_prod + +// v2_to_running, with indirect dispatch driven by the planner's +// per-level totals. +// +// Sizing: small by design (num_subtasks=1, num_columns=32, N=256) so +// noble's projective bucket sum runs instantly in the browser. Each +// bucket gets ~8 points; pair-tree needs ~4 levels. + +import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js'; +import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js'; +import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js'; +import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js'; +import { BN254_BASE_FIELD, modInverse } from '../../src/msm_webgpu/cuzk/bn254.js'; +import { runSmvpV2PairTree } from '../../src/msm_webgpu/cuzk/smvp_v2_pair_tree.js'; +import { makeResultsClient } from './results_post.js'; +import { bn254 } from '@noble/curves/bn254'; + +let NUM_SUBTASKS = 1; +let NUM_COLUMNS = 32; +let INPUT_SIZE = 256; +let S = 16; +let WGI = 64; +let TPB = 64; +let PER_THREAD = 1; +let MAX_LEVELS = 8; +let SEED = 0xc0de; + +function makeRng(seed: number): () => number { + let state = (seed >>> 0) || 1; + return () => { + state = (Math.imul(state, 1664525) + 1013904223) >>> 0; + return state; + }; +} + +function bigintToPackedU32x8(v: bigint): Uint32Array { + const w = new Uint32Array(8); + let x = v; + for (let i = 0; i < 8; i++) { + w[i] = Number(x & 0xffffffffn); + x >>= 32n; + } + return w; +} + +function packedU32x8ToBigint(w: Uint32Array, off: number): bigint { + let v = 0n; + for (let i = 7; i >= 0; i--) v = (v << 32n) | BigInt(w[off + i] >>> 0); + return v; +} + +interface OracleResult { + num_subtasks: number; + num_columns: number; + input_size: number; + s: number; + wgi: number; + tpb: number; + per_thread: number; + max_levels: number; + total_passes: number; + gpu_wall_ms: number; + buckets_checked: number; + buckets_passed: number; + first_mismatches: Array<{ subtask: number; bucket: number; count: number; gpu_x: string; gpu_y: string; ref_x: string; ref_y: string; ok: boolean }>; + all_passed: boolean; +} + +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: Record | null; + results: OracleResult[]; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { state: 'boot', params: null, results: [], error: null, log: [] }; +(window as unknown as { __bench: BenchState }).__bench = benchState; +const resultsClient = makeResultsClient({ page: 'bench-msm-oracle-prod' }); +(window as unknown as { __runId: string }).__runId = resultsClient.runId; + +async function postFinal(): Promise { + await resultsClient.postResults({ + state: benchState.state, + params: benchState.params, + results: benchState.results, + error: benchState.error, + log: benchState.log, + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); +} + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-msm-oracle-prod] ${msg}`); +} + +async function readbackU32(device: GPUDevice, buf: GPUBuffer, bytes: number): Promise { + const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }); + const enc = device.createCommandEncoder(); + enc.copyBufferToBuffer(buf, 0, staging, 0, bytes); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + await staging.mapAsync(GPUMapMode.READ); + const out = new Uint32Array(staging.getMappedRange().slice(0)); + staging.unmap(); + staging.destroy(); + return out; +} + +async function runOracle(device: GPUDevice, sm: ShaderManager, R: bigint, Rinv: bigint, p: bigint): Promise { + log('info', `=== T=${NUM_SUBTASKS} B=${NUM_COLUMNS} N=${INPUT_SIZE} S=${S} WGI=${WGI} TPB=${TPB} PER=${PER_THREAD} MAX_LEVELS=${MAX_LEVELS}`); + const rng = makeRng(SEED); + const G1 = bn254.G1.ProjectivePoint; + const order = bn254.fields.Fr.ORDER; + + const points: Array<{ x: bigint; y: bigint }> = []; + const pointXWords = new Uint32Array(INPUT_SIZE * 8); + const pointYWords = new Uint32Array(INPUT_SIZE * 8); + for (let i = 0; i < INPUT_SIZE; i++) { + let k = 0n; + for (let w = 0; w < 8; w++) k = (k << 32n) | BigInt(rng() >>> 0); + k = k % order; + if (k === 0n) k = 1n; + const aff = G1.BASE.multiply(k).toAffine(); + points.push({ x: aff.x, y: aff.y }); + pointXWords.set(bigintToPackedU32x8((aff.x * R) % p), 8 * i); + pointYWords.set(bigintToPackedU32x8((aff.y * R) % p), 8 * i); + } + log('info', `generated ${INPUT_SIZE} BN254 affine points`); + + const valIdxArr = new Uint32Array(NUM_SUBTASKS * INPUT_SIZE); + const rowPtrArr = new Uint32Array(NUM_SUBTASKS * (NUM_COLUMNS + 1)); + const bucketOf: Uint32Array[] = []; + for (let st = 0; st < NUM_SUBTASKS; st++) { + const bucket = new Uint32Array(INPUT_SIZE); + const counts = new Uint32Array(NUM_COLUMNS); + for (let i = 0; i < INPUT_SIZE; i++) { + const hi = (rng() >>> 16) & 0xffff; + const lo = (rng() >>> 16) & 0xffff; + const v = hi * 0x10000 + lo; + const b = v % NUM_COLUMNS; + bucket[i] = b; + counts[b]++; + } + bucketOf.push(bucket); + const offsets = new Uint32Array(NUM_COLUMNS + 1); + for (let b = 0; b < NUM_COLUMNS; b++) offsets[b + 1] = offsets[b] + counts[b]; + const cursor = new Uint32Array(NUM_COLUMNS); + for (let i = 0; i < INPUT_SIZE; i++) { + const b = bucket[i]; + const slot = offsets[b] + cursor[b]++; + valIdxArr[st * INPUT_SIZE + slot] = i; + } + const rpBase = st * (NUM_COLUMNS + 1); + for (let b = 0; b <= NUM_COLUMNS; b++) rowPtrArr[rpBase + b] = offsets[b]; + } + log('info', `built synthetic CSR for ${NUM_SUBTASKS} window(s)`); + + const mk = (bytes: number, extra: GPUBufferUsageFlags = 0): GPUBuffer => + device.createBuffer({ size: bytes, usage: GPUBufferUsage.STORAGE | extra }); + const val_idx_buf = mk(valIdxArr.byteLength, GPUBufferUsage.COPY_DST); + const row_ptr_buf = mk(rowPtrArr.byteLength, GPUBufferUsage.COPY_DST); + const point_x_buf = mk(INPUT_SIZE * 32, GPUBufferUsage.COPY_DST); + const point_y_buf = mk(INPUT_SIZE * 32, GPUBufferUsage.COPY_DST); + const running_x_buf = mk(NUM_SUBTASKS * NUM_COLUMNS * 32, GPUBufferUsage.COPY_SRC); + const running_y_buf = mk(NUM_SUBTASKS * NUM_COLUMNS * 32, GPUBufferUsage.COPY_SRC); + const bucket_active_buf = mk(NUM_SUBTASKS * NUM_COLUMNS * 4, GPUBufferUsage.COPY_SRC); + + device.queue.writeBuffer(val_idx_buf, 0, valIdxArr as BufferSource); + device.queue.writeBuffer(row_ptr_buf, 0, rowPtrArr as BufferSource); + device.queue.writeBuffer(point_x_buf, 0, pointXWords as BufferSource); + device.queue.writeBuffer(point_y_buf, 0, pointYWords as BufferSource); + + const stats = await runSmvpV2PairTree({ + device, + shaderManager: sm, + num_subtasks: NUM_SUBTASKS, + num_columns: NUM_COLUMNS, + input_size: INPUT_SIZE, + s: S, + tpb: TPB, + per_thread: PER_THREAD, + wgi: WGI, + max_levels: MAX_LEVELS, + val_idx_buf, row_ptr_buf, point_x_buf, point_y_buf, + running_x_buf, running_y_buf, bucket_active_buf, + }); + log('info', `runSmvpV2PairTree returned: ${JSON.stringify(stats)}`); + + const runningXWords = await readbackU32(device, running_x_buf, NUM_SUBTASKS * NUM_COLUMNS * 32); + const runningYWords = await readbackU32(device, running_y_buf, NUM_SUBTASKS * NUM_COLUMNS * 32); + const bucketActive = await readbackU32(device, bucket_active_buf, NUM_SUBTASKS * NUM_COLUMNS * 4); + + const refSumPerBucket = new Map(); + for (let st = 0; st < NUM_SUBTASKS; st++) { + const bucket = bucketOf[st]; + for (let b = 0; b < NUM_COLUMNS; b++) { + let acc = G1.ZERO; + let count = 0; + for (let i = 0; i < INPUT_SIZE; i++) { + if (bucket[i] !== b) continue; + acc = acc.add(G1.fromAffine({ x: points[i].x, y: points[i].y })); + count++; + } + const bucket_global = st * NUM_COLUMNS + b; + refSumPerBucket.set(bucket_global, count === 0 ? null : (acc.is0() ? null : acc.toAffine())); + } + } + + const checks: OracleResult['first_mismatches'] = []; + const mismatches: OracleResult['first_mismatches'] = []; + let passCount = 0; + for (let st = 0; st < NUM_SUBTASKS; st++) { + const bucket = bucketOf[st]; + const counts = new Uint32Array(NUM_COLUMNS); + for (let i = 0; i < INPUT_SIZE; i++) counts[bucket[i]]++; + for (let b = 0; b < NUM_COLUMNS; b++) { + const bucket_global = st * NUM_COLUMNS + b; + const ref = refSumPerBucket.get(bucket_global) ?? null; + const active = bucketActive[bucket_global]; + if (counts[b] === 0) { + if (active !== 0) mismatches.push({ subtask: st, bucket: b, count: 0, gpu_x: 'active=1', gpu_y: '', ref_x: 'empty', ref_y: '', ok: false }); + continue; + } + if (active === 0) { + mismatches.push({ subtask: st, bucket: b, count: counts[b], gpu_x: 'active=0', gpu_y: '', ref_x: ref ? ref.x.toString(16) : 'INF', ref_y: ref ? ref.y.toString(16) : 'INF', ok: false }); + continue; + } + const xMont = packedU32x8ToBigint(runningXWords, bucket_global * 8); + const yMont = packedU32x8ToBigint(runningYWords, bucket_global * 8); + const gx = (xMont * Rinv) % p; + const gy = (yMont * Rinv) % p; + const ok = ref !== null && gx === ref.x && gy === ref.y; + const entry = { subtask: st, bucket: b, count: counts[b], gpu_x: gx.toString(16), gpu_y: gy.toString(16), ref_x: ref ? ref.x.toString(16) : 'INF', ref_y: ref ? ref.y.toString(16) : 'INF', ok }; + checks.push(entry); + if (ok) passCount++; + else if (mismatches.length < 8) mismatches.push(entry); + } + } + const allPassed = mismatches.length === 0; + + if (allPassed) { + log('ok', `oracle PASS — ${passCount}/${checks.length} buckets match noble reference (prod orchestrator)`); + } else { + log('err', `oracle FAIL — ${mismatches.length} mismatches in first ${checks.length} buckets`); + for (const m of mismatches.slice(0, 8)) { + log('err', ` subtask=${m.subtask} bucket=${m.bucket} count=${m.count}`); + log('err', ` gpu: x=${m.gpu_x} y=${m.gpu_y}`); + log('err', ` ref: x=${m.ref_x} y=${m.ref_y}`); + } + } + + val_idx_buf.destroy(); + row_ptr_buf.destroy(); + point_x_buf.destroy(); + point_y_buf.destroy(); + running_x_buf.destroy(); + running_y_buf.destroy(); + bucket_active_buf.destroy(); + + return { + num_subtasks: NUM_SUBTASKS, + num_columns: NUM_COLUMNS, + input_size: INPUT_SIZE, + s: S, + wgi: WGI, + tpb: TPB, + per_thread: PER_THREAD, + max_levels: MAX_LEVELS, + total_passes: stats.total_passes, + gpu_wall_ms: stats.gpu_wall_ms, + buckets_checked: checks.length, + buckets_passed: passCount, + first_mismatches: mismatches, + all_passed: allPassed, + }; +} + +function parseParams() { + const qp = new URLSearchParams(window.location.search); + if (qp.get('subtasks')) NUM_SUBTASKS = parseInt(qp.get('subtasks')!, 10); + if (qp.get('columns')) NUM_COLUMNS = parseInt(qp.get('columns')!, 10); + if (qp.get('input')) INPUT_SIZE = parseInt(qp.get('input')!, 10); + if (qp.get('s')) S = parseInt(qp.get('s')!, 10); + if (qp.get('wgi')) WGI = parseInt(qp.get('wgi')!, 10); + if (qp.get('tpb')) TPB = parseInt(qp.get('tpb')!, 10); + if (qp.get('per')) PER_THREAD = parseInt(qp.get('per')!, 10); + if (qp.get('lvls')) MAX_LEVELS = parseInt(qp.get('lvls')!, 10); + if (qp.get('seed')) SEED = parseInt(qp.get('seed')!, 10); + return { subtasks: NUM_SUBTASKS, columns: NUM_COLUMNS, input: INPUT_SIZE, s: S, wgi: WGI, tpb: TPB, per_thread: PER_THREAD, max_levels: MAX_LEVELS, seed: SEED }; +} + +async function main() { + try { + if (!('gpu' in navigator)) throw new Error('navigator.gpu missing'); + const params = parseParams(); + benchState.params = params; + log('info', `params: ${JSON.stringify(params)}`); + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + const p = BN254_BASE_FIELD; + const miscParams = compute_misc_params(p, 13); + const R = miscParams.r; + const Rinv = modInverse(R, p); + const sm = new ShaderManager(NUM_SUBTASKS, NUM_COLUMNS, BN254_CURVE_CONFIG, false); + const r = await runOracle(device, sm, R, Rinv, p); + benchState.results.push(r); + resultsClient.postProgress({ + kind: 'oracle_prod_done', + all_passed: r.all_passed, + buckets_passed: r.buckets_passed, + buckets_checked: r.buckets_checked, + gpu_wall_ms: r.gpu_wall_ms, + }); + benchState.state = 'done'; + log('ok', 'done'); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main() + .catch(e => { + const msg = e instanceof Error ? e.message : String(e); + log('err', `unhandled: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + }) + .finally(() => { + postFinal().catch(() => {}); + }); diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs index b9c2fc8ef1e7..4bb7c19a9751 100644 --- a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs +++ b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs @@ -137,6 +137,7 @@ const pageMap = { "bench-planner": "/dev/msm-webgpu/bench-planner.html", "bench-csr-to-v2": "/dev/msm-webgpu/bench-csr-to-v2.html", "bench-msm-oracle": "/dev/msm-webgpu/bench-msm-oracle.html", + "bench-msm-oracle-prod": "/dev/msm-webgpu/bench-msm-oracle-prod.html", "bench-smvp-tree": "/dev/msm-webgpu/bench-smvp-tree.html", sanity: "/dev/msm-webgpu/index.html", }; diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts b/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts index de9979bd1237..9d4b153d4272 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts @@ -339,15 +339,6 @@ function destroyScratch(scratch: Scratch): void { scratch.v2RunParams.destroy(); } -function buildPadPair(M: number): Uint32Array { - const padPair = new Uint32Array(2 * PG * 2 * 4); - for (let i = 0; i < padPair.length; i++) { - padPair[i] = (0x9e3779b9 * (i + 1)) >>> 0; - } - if (padPair[0] === padPair[PG * 4]) padPair[PG * 4] ^= 1; - return padPair; -} - /** * Run the v2 pair-tree MSM bucket-accumulate for ALL pippenger windows * in a single GPU submit. @@ -384,12 +375,11 @@ export async function runSmvpV2PairTree(opts: SmvpV2PairTreeOptions): Promise Date: Wed, 20 May 2026 15:51:45 +0000 Subject: [PATCH 27/33] fix(bb/msm): single-binding active_sums in csr_to_v2 converter The csr_to_v2_active_sums shader had separate bindings for active_sums_x and active_sums_y. The standalone bench-csr-to-v2 was fine because it binds those to two physically distinct buffers, but the smvp_v2_pair_tree orchestrator binds them to two subviews of one combined-SoA scratch buffer at offsets 0 and PG*M*PG_VEC4_BYTES. For M = input_size + 2 = 258 that y-plane offset is 8256 bytes, which is not aligned to WebGPU's default minStorageBufferOffsetAlignment of 256. On M2 Chrome 148 the orchestrator's dispatch chain proceeded silently with no writes landing in any output (running_x_buf / running_y_buf / bucket_active_buf all readback-zero), and the v2_to_running adapter was reporting bucket_active = 0 for every bucket -- even with a diagnostic shader that wrote a constant unconditionally. Fix: the shader now takes a single combined active_sums binding and reads M from params.y. Plane x lives at active_sums[PG*slot + v]; plane y lives at active_sums[PG*M + PG*slot + v]. This matches the combined-SoA layout the four v2 pair-tree kernels (marshal_pairs / pair_disjoint_tree / scatter_pairs / carry_copy) already use, and lets the orchestrator bind the whole scratch.activeA buffer to a single entry. bench-csr-to-v2 host validation is updated to read back from the single combined buffer and compare each plane against the host reference. --- .../ts/dev/msm-webgpu/bench-csr-to-v2.ts | 62 +++++++------------ .../src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts | 10 ++- .../src/msm_webgpu/wgsl/_generated/shaders.ts | 56 +++++++++++------ .../cuzk/csr_to_v2_active_sums.template.wgsl | 56 +++++++++++------ 4 files changed, 102 insertions(+), 82 deletions(-) diff --git a/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.ts b/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.ts index 9fdb7ab4394e..44a0458b96d1 100644 --- a/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.ts +++ b/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.ts @@ -215,12 +215,13 @@ async function runOne(device: GPUDevice, sm: ShaderManager): Promise> 1, probeStart, INPUT_SIZE - 1]) { - const slot = viBase + k; - const ptIdx = valIdx[slot]; - for (let w = 0; w < 8; w++) { - const got = gpuActiveX[slot * 8 + w]; - const want = refBasesX[ptIdx * 8 + w]; - if (got !== want) { - xFails++; - if (mismatches.length < 8) mismatches.push(`activeX[s=${s} k=${k} w=${w}]: gpu=${got} ref=${want}`); - } - const gotY = gpuActiveY[slot * 8 + w]; - const wantY = refBasesY[ptIdx * 8 + w]; - if (gotY !== wantY) { - yFails++; - if (mismatches.length < 8) mismatches.push(`activeY[s=${s} k=${k} w=${w}]: gpu=${gotY} ref=${wantY}`); - } + const planeYU32Off = 2 * 2 * M; + for (let slot = 0; slot < totalSlots; slot++) { + const ptIdx = valIdx[slot]; + for (let w = 0; w < 8; w++) { + const got = gpuActive[slot * 8 + w]; + const want = refBasesX[ptIdx * 8 + w]; + if (got !== want) { + xFails++; + if (mismatches.length < 8) mismatches.push(`activeX[slot=${slot} w=${w}]: gpu=${got} ref=${want}`); + } + const gotY = gpuActive[planeYU32Off + slot * 8 + w]; + const wantY = refBasesY[ptIdx * 8 + w]; + if (gotY !== wantY) { + yFails++; + if (mismatches.length < 8) mismatches.push(`activeY[slot=${slot} w=${w}]: gpu=${gotY} ref=${wantY}`); } } } - // Full-pass byte compare (cheap; the buffers are u32 arrays). - for (let k = 0; k < totalSlots * 8; k++) { - const ptIdx = valIdx[k >> 3]; - const w = k & 7; - if (gpuActiveX[k] !== refBasesX[ptIdx * 8 + w]) xFails++; - if (gpuActiveY[k] !== refBasesY[ptIdx * 8 + w]) yFails++; - } let mFails = 0; for (let s = 0; s < NUM_SUBTASKS; s++) { const rpBase = s * (NUM_COLUMNS + 1); @@ -443,8 +430,7 @@ async function runOne(device: GPUDevice, sm: ShaderManager): Promise per field -// element). The copy is a raw element copy — destination element bytes -// equal source element bytes; no unpack / pack needed. -// -// Sign handling: cuZK encodes signed slices via bucket index, not via -// point negation, so the converter does not flip y. The finalize pass -// negates y for negative-bucket contributions. +// The copy is a raw element copy — destination element bytes equal +// source element bytes; no unpack / pack needed. Sign handling stays at +// finalize (cuZK encodes signed slices via bucket index, not via point +// negation). + +const PG: u32 = 2u; @group(0) @binding(0) var val_idx: array; @@ -6764,12 +6776,12 @@ var new_point_x: array>; @group(0) @binding(2) var new_point_y: array>; @group(0) @binding(3) -var active_sums_x: array>; -@group(0) @binding(4) -var active_sums_y: array>; +var active_sums: array>; -// params[0] = total_slots (num_subtasks * input_size) -@group(0) @binding(5) +// params.x = total_slots (num_subtasks * input_size, OR per-window +// input_size when the caller binds val_idx as a per-window subview) +// params.y = M (elements per plane in active_sums) +@group(0) @binding(4) var params: vec4; @compute @@ -6781,12 +6793,18 @@ fn main(@builtin(global_invocation_id) gid: vec3) { return; } + let M = params[1]; let pt_idx = val_idx[slot]; - active_sums_x[2u * slot] = new_point_x[2u * pt_idx]; - active_sums_x[2u * slot + 1u] = new_point_x[2u * pt_idx + 1u]; - active_sums_y[2u * slot] = new_point_y[2u * pt_idx]; - active_sums_y[2u * slot + 1u] = new_point_y[2u * pt_idx + 1u]; + let plane_x_base = PG * slot; + let plane_y_base = PG * M + PG * slot; + let src_x = PG * pt_idx; + let src_y = PG * pt_idx; + + active_sums[plane_x_base + 0u] = new_point_x[src_x + 0u]; + active_sums[plane_x_base + 1u] = new_point_x[src_x + 1u]; + active_sums[plane_y_base + 0u] = new_point_y[src_y + 0u]; + active_sums[plane_y_base + 1u] = new_point_y[src_y + 1u]; {{{ recompile }}} } diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/csr_to_v2_active_sums.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/csr_to_v2_active_sums.template.wgsl index e52219756a56..7abf9766e692 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/csr_to_v2_active_sums.template.wgsl +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/csr_to_v2_active_sums.template.wgsl @@ -5,18 +5,30 @@ // at the indices listed in val_idx (cuZK transpose output, bucket-major // per subtask). // -// Per (subtask s, slot k) thread, with slot = s * input_size + k: -// pt_idx = val_idx[slot] -// active_sums_x[slot] = new_point_x[pt_idx] -// active_sums_y[slot] = new_point_y[pt_idx] +// active_sums is one combined-SoA storage buffer (matching what the v2 +// pair-tree kernels marshal_pairs / pair_disjoint_tree / scatter_pairs +// / carry_copy consume): +// plane 0 (x) at vec4 indices [0, PG * M) +// plane 1 (y) at vec4 indices [PG * M, 2 * PG * M) +// per-element layout: PG=2 vec4 at [PG*elem, PG*elem+1]. +// M (elements per plane) is passed via params.y so this shader uses a +// single storage binding instead of two subviews of the same buffer — +// the subview path tripped a silent dispatch no-op on M2 Chrome 148 +// because plane-y's byte offset (PG*M*16 = 8256 for M=258) is not a +// multiple of WebGPU's default minStorageBufferOffsetAlignment of 256. // -// Both source and destination are packed 8×u32 (two vec4 per field -// element). The copy is a raw element copy — destination element bytes -// equal source element bytes; no unpack / pack needed. +// Per (subtask s, slot k) thread with slot = s * input_size + k: +// pt_idx = val_idx[slot] +// active_sums[PG * slot + v] = new_point_x[PG * pt_idx + v] +// active_sums[PG * M + PG * slot + v] = new_point_y[PG * pt_idx + v] +// for v in {0, 1}. // -// Sign handling: cuZK encodes signed slices via bucket index, not via -// point negation, so the converter does not flip y. The finalize pass -// negates y for negative-bucket contributions. +// The copy is a raw element copy — destination element bytes equal +// source element bytes; no unpack / pack needed. Sign handling stays at +// finalize (cuZK encodes signed slices via bucket index, not via point +// negation). + +const PG: u32 = 2u; @group(0) @binding(0) var val_idx: array; @@ -25,12 +37,12 @@ var new_point_x: array>; @group(0) @binding(2) var new_point_y: array>; @group(0) @binding(3) -var active_sums_x: array>; -@group(0) @binding(4) -var active_sums_y: array>; +var active_sums: array>; -// params[0] = total_slots (num_subtasks * input_size) -@group(0) @binding(5) +// params.x = total_slots (num_subtasks * input_size, OR per-window +// input_size when the caller binds val_idx as a per-window subview) +// params.y = M (elements per plane in active_sums) +@group(0) @binding(4) var params: vec4; @compute @@ -42,12 +54,18 @@ fn main(@builtin(global_invocation_id) gid: vec3) { return; } + let M = params[1]; let pt_idx = val_idx[slot]; - active_sums_x[2u * slot] = new_point_x[2u * pt_idx]; - active_sums_x[2u * slot + 1u] = new_point_x[2u * pt_idx + 1u]; - active_sums_y[2u * slot] = new_point_y[2u * pt_idx]; - active_sums_y[2u * slot + 1u] = new_point_y[2u * pt_idx + 1u]; + let plane_x_base = PG * slot; + let plane_y_base = PG * M + PG * slot; + let src_x = PG * pt_idx; + let src_y = PG * pt_idx; + + active_sums[plane_x_base + 0u] = new_point_x[src_x + 0u]; + active_sums[plane_x_base + 1u] = new_point_x[src_x + 1u]; + active_sums[plane_y_base + 0u] = new_point_y[src_y + 0u]; + active_sums[plane_y_base + 1u] = new_point_y[src_y + 1u]; {{{ recompile }}} } From 57e4544d783727d1c3ac6757cf3e0bf01cc57eb7 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Wed, 20 May 2026 15:59:09 +0000 Subject: [PATCH 28/33] fix(bb/msm): planner_v2_prod pad-fills partial last chunk The orchestrator passed lvls=0 cleanly after the previous fix (10/32 buckets matched noble) but the full lvls=8 run still mismatched on 22/32 buckets. Root cause: planner_v2_prod only writes the first total_pairs entries of chunk_plan and scatter_plan; the last chunk is partial (total_pairs is not a multiple of S) so its tail slots stay stale. The marshal kernel iterates ALL S=16 slots per dispatched chunk regardless of how many real pairs the chunk holds, so the stale tail slots feed garbage (idx_l, idx_r) into the disjoint pair-sum, which then scatters the garbage to whatever stale scatter_plan dst happens to be -- corrupting downstream bucket reductions. Fix: add a Phase E (tid==TPB-1, after a workgroupBarrier) to planner_v2_prod that fills chunk_plan[total_pairs .. num_chunks*S) with (pad_left, pad_right) and scatter_plan with discard. The constants come from new uniform fields params.y/z/w. The orchestrator reserves 3 tail slots in active_sums (M = input_size + 3) for pad_left, pad_right, and discard at indices input_size, input_size+1, input_size+2; initialises pad_left and pad_right in both ping-pong buffers with distinct-x packed-Mont values so the disjoint kernel's lean affine add is well-defined; and writes the indices into plannerParams. Indirect-dispatch-only-no-pad-fill (which I'd reached for in the prior commit) handles the OVER-allocation case (num_chunks=0 -> 0 workgroups dispatched) but does nothing for the PARTIAL-last-chunk case. Both are needed; the planner now does the partial-chunk pad and drops the max(1u, ...) clamp from totals[3..7] so num_chunks=0 truly dispatches zero workgroups. --- .../src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts | 34 +++++++++++++++---- .../src/msm_webgpu/wgsl/_generated/shaders.ts | 24 +++++++++++-- .../cuzk/ba_planner_v2_prod.template.wgsl | 24 +++++++++++-- 3 files changed, 69 insertions(+), 13 deletions(-) diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts b/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts index 2d3c2b87dbec..51ae0e363a34 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts @@ -261,7 +261,12 @@ function allocScratch( tpb: number, per_thread: number, ): Scratch { - const M = input_size + 2; + // M = real slots (input_size) + 3 reserved tail slots: pad_left, + // pad_right, discard. Reserved slots aren't touched by the converter + // or by real bucket reductions (which only address [0, total_actives) < + // input_size), so the planner can safely pad-fill the last partial + // chunk of chunk_plan / scatter_plan with these constant indices. + const M = input_size + 3; const maxChunks = Math.max(1, Math.ceil(input_size / 2 / s) + 1); const mk = (bytes: number, extra: GPUBufferUsageFlags = 0): GPUBuffer => @@ -366,19 +371,34 @@ export async function runSmvpV2PairTree(opts: SmvpV2PairTreeOptions): Promise>> 0; + if (padPair[0] === padPair[PG * 4]) padPair[PG * 4] ^= 1; + const planeXPadByteOff = PG * padLeft * PG_VEC4_BYTES; + const planeYPadByteOff = PG * M * PG_VEC4_BYTES + PG * padLeft * PG_VEC4_BYTES; + device.queue.writeBuffer(scratch.activeA, planeXPadByteOff, padPair as BufferSource); + device.queue.writeBuffer(scratch.activeA, planeYPadByteOff, padPair as BufferSource); + device.queue.writeBuffer(scratch.activeB, planeXPadByteOff, padPair as BufferSource); + device.queue.writeBuffer(scratch.activeB, planeYPadByteOff, padPair as BufferSource); const encoder = device.createCommandEncoder(); let totalPasses = 0; diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index e43d2a9040b9..2b56919af6db 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -2675,6 +2675,9 @@ const WGI: u32 = {{ wgi }}u; @group(0) @binding(7) var totals: array; @group(0) @binding(8) var params: vec4; // params.x = B +// params.y = pad_left_idx (active_sums index used for chunk_plan tail pad left operand) +// params.z = pad_right_idx (chunk_plan tail pad right operand; must differ from pad_left_idx in x) +// params.w = discard_idx (scatter_plan tail dst; slot that the next level never reads) var pair_scan: array; var carry_scan: array; @@ -2742,12 +2745,12 @@ fn main(@builtin(local_invocation_id) lid: vec3) { totals[0] = tp; totals[1] = tc; totals[2] = tn; - let num_chunks = max(1u, (tp + S - 1u) / S); + let num_chunks = (tp + S - 1u) / S; totals[3] = num_chunks; - totals[4] = max(1u, (num_chunks + WGI - 1u) / WGI); + totals[4] = (num_chunks + WGI - 1u) / WGI; totals[5] = 1u; totals[6] = 1u; - totals[7] = max(1u, (tc + WGI - 1u) / WGI); + totals[7] = (tc + WGI - 1u) / WGI; totals[8] = 1u; totals[9] = 1u; } @@ -2786,6 +2789,21 @@ fn main(@builtin(local_invocation_id) lid: vec3) { local_new_off += nc; } + workgroupBarrier(); + if (tid == TPB - 1u) { + let tp = pair_scan[tid]; + let num_chunks = (tp + S - 1u) / S; + let pad_end = num_chunks * S; + let pad_left = params.y; + let pad_right = params.z; + let discard = params.w; + for (var i: u32 = tp; i < pad_end; i = i + 1u) { + chunk_plan[2u * i + 0u] = pad_left; + chunk_plan[2u * i + 1u] = pad_right; + scatter_plan[i] = discard; + } + } + {{{ recompile }}} } `; diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_prod.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_prod.template.wgsl index f3ec2f414e8e..0f0ca81ecda8 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_prod.template.wgsl +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_prod.template.wgsl @@ -55,6 +55,9 @@ const WGI: u32 = {{ wgi }}u; @group(0) @binding(7) var totals: array; @group(0) @binding(8) var params: vec4; // params.x = B +// params.y = pad_left_idx (active_sums index used for chunk_plan tail pad left operand) +// params.z = pad_right_idx (chunk_plan tail pad right operand; must differ from pad_left_idx in x) +// params.w = discard_idx (scatter_plan tail dst; slot that the next level never reads) var pair_scan: array; var carry_scan: array; @@ -122,12 +125,12 @@ fn main(@builtin(local_invocation_id) lid: vec3) { totals[0] = tp; totals[1] = tc; totals[2] = tn; - let num_chunks = max(1u, (tp + S - 1u) / S); + let num_chunks = (tp + S - 1u) / S; totals[3] = num_chunks; - totals[4] = max(1u, (num_chunks + WGI - 1u) / WGI); + totals[4] = (num_chunks + WGI - 1u) / WGI; totals[5] = 1u; totals[6] = 1u; - totals[7] = max(1u, (tc + WGI - 1u) / WGI); + totals[7] = (tc + WGI - 1u) / WGI; totals[8] = 1u; totals[9] = 1u; } @@ -166,5 +169,20 @@ fn main(@builtin(local_invocation_id) lid: vec3) { local_new_off += nc; } + workgroupBarrier(); + if (tid == TPB - 1u) { + let tp = pair_scan[tid]; + let num_chunks = (tp + S - 1u) / S; + let pad_end = num_chunks * S; + let pad_left = params.y; + let pad_right = params.z; + let discard = params.w; + for (var i: u32 = tp; i < pad_end; i = i + 1u) { + chunk_plan[2u * i + 0u] = pad_left; + chunk_plan[2u * i + 1u] = pad_right; + scatter_plan[i] = discard; + } + } + {{{ recompile }}} } From 3c0758b59e3e4a37f2da75fd5b1c21e7011572ba Mon Sep 17 00:00:00 2001 From: AztecBot Date: Wed, 20 May 2026 16:14:48 +0000 Subject: [PATCH 29/33] fix(bb/msm): rename 'discard' to 'discard_idx' in planner_v2_prod WGSL reserves the 'discard' keyword (fragment-shader statement) even inside compute shaders, so 'let discard = params.w;' failed parse with "expected identifier for 'let' declaration". The rename keeps the pad-fill semantics identical. --- barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts | 4 ++-- .../src/msm_webgpu/wgsl/cuzk/ba_planner_v2_prod.template.wgsl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index 2b56919af6db..bef7e380beeb 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -2796,11 +2796,11 @@ fn main(@builtin(local_invocation_id) lid: vec3) { let pad_end = num_chunks * S; let pad_left = params.y; let pad_right = params.z; - let discard = params.w; + let discard_idx = params.w; for (var i: u32 = tp; i < pad_end; i = i + 1u) { chunk_plan[2u * i + 0u] = pad_left; chunk_plan[2u * i + 1u] = pad_right; - scatter_plan[i] = discard; + scatter_plan[i] = discard_idx; } } diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_prod.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_prod.template.wgsl index 0f0ca81ecda8..811118992dc5 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_prod.template.wgsl +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_prod.template.wgsl @@ -176,11 +176,11 @@ fn main(@builtin(local_invocation_id) lid: vec3) { let pad_end = num_chunks * S; let pad_left = params.y; let pad_right = params.z; - let discard = params.w; + let discard_idx = params.w; for (var i: u32 = tp; i < pad_end; i = i + 1u) { chunk_plan[2u * i + 0u] = pad_left; chunk_plan[2u * i + 1u] = pad_right; - scatter_plan[i] = discard; + scatter_plan[i] = discard_idx; } } From aa401d60c3a621229ef39e2d86ba3ec74c726eb8 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Wed, 20 May 2026 17:06:52 +0000 Subject: [PATCH 30/33] feat(bb/msm): wire use_v2_pair_tree into compute_bn254_msm_batch_affine Step 1 of the v2 -> production MSM integration: a new use_v2_pair_tree opt-in flag on compute_bn254_msm_batch_affine that plumbs through compute_curve_msm into smvp_batch_affine_gpu and replaces the cuZK schedule + batch_inverse_parallel + apply_scatter round-loop with the v2 bin-packed pair-tree orchestrator (runSmvpV2PairTree). When the flag is set, the dispatch sequence becomes: ba_init (existing; seeds running_x/y / bucket_active from CSR) encoder.finish() + queue.submit + onSubmittedWorkDone (mid-flush so init + upstream transpose are visible) runSmvpV2PairTree(...) (csr_to_v2_meta + csr_to_v2_active_sums + per-level planner_v2_prod + indirect-dispatch marshal_prod + pair_disjoint_tree_prod + scatter_pairs_prod + carry_copy_prod + v2_to_running, per window; internal single submit + onSubmittedWorkDone) fresh commandEncoder finalize_collect + finalize_apply (existing; consumes the packed running_x/y + bucket_active the orchestrator wrote) BPR + horner_reduce + readback (existing, unchanged) Forces packed 8x u32 storage (the prod kernels' native layout). Mutually exclusive with use_tree_reduce, which takes precedence if both are set. Takes precedence over fused_revcarry (both imply packed -- v2 is the strictly better packed path). The orchestrator's per-bucket reduced sums already passed the bench-msm-oracle-prod noble cross-check at N=256 / B=32 (32/32 buckets PASS, GPU wall 54 ms). This commit lets the production MSM call exercise the same path -- compute_bn254_msm_batch_affine(..., true) returns the end-to-end MSM result with the v2 orchestrator providing the bucket-accumulate stage. End-to-end noble validation against bn254.G1.msm is the next step. --- .../ts/src/msm_webgpu/cuzk/batch_affine.ts | 67 +++++++++++++++++-- barretenberg/ts/src/msm_webgpu/msm.ts | 25 ++++++- 2 files changed, 83 insertions(+), 9 deletions(-) diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine.ts b/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine.ts index 02dc4e9fc81f..59e8736cc069 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine.ts @@ -25,6 +25,7 @@ import type { GpuContext } from './gpu_context.js'; import { ShaderManager } from './shader_manager.js'; import { create_bind_group_layout, execute_pipeline, execute_pipeline_indirect } from './gpu.js'; import { runTreeReduce } from './smvp_tree.js'; +import { runSmvpV2PairTree } from './smvp_v2_pair_tree.js'; // Per-stage profiling budgets for the per-round loop. Sized to cover // every active round of every family at every benchmark size, including @@ -154,6 +155,18 @@ export const smvp_batch_affine_gpu = async ( // consumption are unaffected. Mutually exclusive with use_tree_reduce // (tree path takes precedence if both are set). Currently bn254 only. fused_revcarry = false, + // When true, replace the entire round-loop (schedule + batch_inverse_parallel + // + apply_scatter) with the v2 bin-packed pair-tree orchestrator + // (runSmvpV2PairTree). The orchestrator runs the per-bucket + // accumulate on PACKED 8x u32 storage via csr_to_v2_* + planner_v2_prod + + // indirect-dispatch marshal / disjoint / scatter / carry + v2_to_running. + // ba_init still runs first (it seeds running_x/y / bucket_active / + // bucket_cursor from the CSR row pointers) and the existing + // finalize_collect / finalize_apply / BPR / horner stages consume the + // running_x/y / bucket_active that runSmvpV2PairTree writes. Forces + // packed = true. Mutually exclusive with use_tree_reduce (which takes + // precedence). Currently bn254 only. + use_v2_pair_tree = false, ): Promise => { // The tree-reduce path needs to mid-flush commandEncoder so its CSR // reads see the current call's transpose output (not stale data from @@ -166,12 +179,13 @@ export const smvp_batch_affine_gpu = async ( const num_words = shaderManager.num_words; const limb_byte_length = num_words * 4; // PACKED 8×u32 (32 B/element) storage for the running_x/y field - // workspace. Implied by the fused round path (tree-reduce takes - // precedence and is mutually exclusive). Arithmetic still unpacks to - // the BigInt limb layout in-register; only the storage bytes change. - // pair_delta / pair_inv / pair_prefix are NOT packed — they are unused - // on the fused round path and consumed BigInt-layout by finalize. - const packed = fused_revcarry && !use_tree_reduce; + // workspace. Implied by the fused round path or the v2 pair-tree + // orchestrator (tree-reduce takes precedence and is mutually + // exclusive). Arithmetic still unpacks to the BigInt limb layout + // in-register; only the storage bytes change. pair_delta / pair_inv / + // pair_prefix are NOT packed -- they are unused on the fused / v2 + // round paths and consumed BigInt-layout by finalize. + const packed = (fused_revcarry || use_v2_pair_tree) && !use_tree_reduce; const field_elem_byte_length = packed ? 32 : limb_byte_length; // WINDOWS_PER_BATCH pools the pair pools of WPB consecutive subtasks @@ -952,6 +966,47 @@ export const smvp_batch_affine_gpu = async ( // They get GC'd after the caller submits + finishes commandEncoder. // No explicit `.destroy()` here — destroying buffers referenced by // a pending bind group invalidates the encoder. + } else if (use_v2_pair_tree) { + // 2''. v2 bin-packed pair-tree orchestrator. Replaces the per-round + // schedule + batch_inverse_parallel + apply_scatter loop with a + // single per-window submit chain (csr_to_v2_meta + + // csr_to_v2_active_sums + per-level planner_v2_prod + + // indirect-dispatch marshal_prod + pair_disjoint_tree_prod + + // scatter_pairs_prod + carry_copy_prod + v2_to_running). Writes + // running_x / running_y / bucket_active in the packed 8x u32 layout + // the existing finalize_collect + finalize_apply consume. + // + // ba_init above already seeded running_x / running_y / bucket_active + // from val_idx[row_begin] and the per-bucket non-empty flags; the + // v2_to_running adapter overwrites those for non-empty buckets and + // (re)writes bucket_active for all of them. Empty buckets keep + // init's zero state, which finalize_collect skips via + // bucket_active == 0. + // + // Mid-flush the caller's encoder so init + upstream transpose (CSR + // row_ptr and val_idx) are visible to the orchestrator's reads + // before they happen. runSmvpV2PairTree runs its own command + // encoder + single submit + onSubmittedWorkDone internally; we + // hand a fresh encoder to the caller for finalize and downstream + // BPR / horner stages. + device.queue.submit([commandEncoder.finish()]); + await device.queue.onSubmittedWorkDone(); + await runSmvpV2PairTree({ + device, + shaderManager, + num_subtasks, + num_columns, + input_size, + val_idx_buf: all_csc_val_idxs_sb, + row_ptr_buf: all_csc_col_ptr_sb, + point_x_buf: point_x_sb, + point_y_buf: point_y_sb, + running_x_buf: running_x_sb, + running_y_buf: running_y_sb, + bucket_active_buf: bucket_active_sb, + }); + commandEncoder = device.createCommandEncoder(); + commandEncoderRef.current = commandEncoder; } else { // 2. Round loop, cross-subtask parallel. // diff --git a/barretenberg/ts/src/msm_webgpu/msm.ts b/barretenberg/ts/src/msm_webgpu/msm.ts index a1cc4cc25ab4..5622fc1824e6 100644 --- a/barretenberg/ts/src/msm_webgpu/msm.ts +++ b/barretenberg/ts/src/msm_webgpu/msm.ts @@ -256,6 +256,17 @@ export const compute_bn254_msm_batch_affine = async ( // batch_inverse_parallel + apply_scatter dispatches. Init / schedule / // finalize stay unchanged (same BigInt-layout buffers). bn254 only. fused_revcarry = false, + // Opt-in: replace the cuZK schedule + batch_inverse_parallel + + // apply_scatter round-loop with the v2 bin-packed pair-tree orchestrator + // (cuzk/smvp_v2_pair_tree.ts). Per window: csr_to_v2_meta + + // csr_to_v2_active_sums + per-level planner_v2_prod + indirect-dispatch + // marshal / disjoint / scatter / carry + v2_to_running. Writes the same + // running_x / running_y / bucket_active outputs the existing + // finalize_collect + finalize_apply consume, so BPR / horner stay + // unchanged. Forces packed 8x u32 storage (the prod kernels' native + // layout). Mutually exclusive with use_tree_reduce; takes precedence + // over fused_revcarry. bn254 only. + use_v2_pair_tree = false, ): Promise<{ x: bigint; y: bigint }> => compute_curve_msm( // Cached path: `baseAffinePoints` is ignored. Uint8Array cast keeps @@ -275,6 +286,7 @@ export const compute_bn254_msm_batch_affine = async ( bpr_inner_loop, use_tree_reduce, fused_revcarry, + use_v2_pair_tree, ); // GLV cold-path entry points (compute_bn254_msm_glv and @@ -514,6 +526,12 @@ const compute_curve_msm = async ( // apply_scatter dispatches. Init / schedule / finalize unchanged. // Forwarded into smvp_batch_affine_gpu. bn254 only. fused_revcarry = false, + // Opt-in: replace the cuZK schedule + batch_inverse_parallel + + // apply_scatter round-loop with the v2 bin-packed pair-tree orchestrator + // (runSmvpV2PairTree). Forces packed 8x u32 storage; mutually exclusive + // with use_tree_reduce; takes precedence over fused_revcarry. See + // compute_bn254_msm_batch_affine for the full description. + use_v2_pair_tree = false, ): Promise<{ x: bigint; y: bigint }> => { const curveParams = compute_misc_params(curveConfig.baseFieldModulus, curveConfig.wordSize); const num_words = curveParams.num_words; @@ -524,9 +542,9 @@ const compute_curve_msm = async ( // BigInt limb layout in-register; only storage-buffer bytes change. // The final gpu_horner_sums_* result buffer and the raw point input // stay BigInt-layout so host decoding is unchanged. Tree-reduce takes - // precedence (mutually exclusive with fused), so packed implies the - // batch_affine fused round path. - const packed = fused_revcarry && !use_tree_reduce; + // precedence (mutually exclusive with fused / v2-pair-tree); the v2 + // pair-tree orchestrator natively works on packed 8x u32 storage. + const packed = (fused_revcarry || use_v2_pair_tree) && !use_tree_reduce; const effective_scalar_byte_length = glv_override?.scalar_byte_length ?? curveConfig.scalarByteLength; const input_size = cached_bases ? cached_bases.input_size : (scalars as Buffer).length / effective_scalar_byte_length; @@ -966,6 +984,7 @@ const compute_curve_msm = async ( cached_bases !== undefined && context !== undefined, use_tree_reduce, fused_revcarry, + use_v2_pair_tree, ); // Tree-reduce path may have replaced the encoder; re-bind so // subsequent BPR / readback operations target the active encoder. From 9c75f45a04afa7f7218d24414cd3e683bffc94fd Mon Sep 17 00:00:00 2001 From: AztecBot Date: Wed, 20 May 2026 17:33:38 +0000 Subject: [PATCH 31/33] feat(bb/msm): multi-workgroup v2 planner (3-pass scan) Splits the single-workgroup ba_planner_v2_prod into a 3-pass chain so the planner scales beyond a single workgroup's TPB * PER_THREAD bucket limit. Required for use_v2_pair_tree at production num_columns (B >= 32k) where the single-WG planner blew out register/shared-memory limits. Pass 1 (per-tile local scan): one WG per TILE = TPB * PER_THREAD buckets, writes per-bucket local exclusive prefix offsets + per-WG inclusive totals. Pass 2 (cross-tile scan): single small WG scans wg_totals into per-WG exclusive global starts, emits totals[0..9] + dispatch_arg triples, pad-fills the last partial chunk. Pass 3 (per-tile scatter): one WG per TILE, uses wg_global_start + bucket_local_off to compute global offsets and emit chunk/scatter/carry plans + new_offsets. TILE = 1024 (tpb=256, per_thread=4), scan capacity = 1024, so this covers num_columns up to 2^20 with PER_THREAD staying small enough to keep the per-thread private arrays in registers. --- .../ts/src/msm_webgpu/cuzk/shader_manager.ts | 60 +++ .../src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts | 174 +++++++-- .../src/msm_webgpu/wgsl/_generated/shaders.ts | 360 +++++++++++++++++- .../ba_planner_v2_mwg_local.template.wgsl | 124 ++++++ .../cuzk/ba_planner_v2_mwg_scan.template.wgsl | 141 +++++++ .../ba_planner_v2_mwg_scatter.template.wgsl | 87 +++++ 6 files changed, 909 insertions(+), 37 deletions(-) create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_local.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_scan.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_scatter.template.wgsl diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts index 24bd0a2f4ba1..9fe03e280a49 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts @@ -15,6 +15,9 @@ import { ba_planner_bench as ba_planner_bench_shader, ba_planner_v2_bench as ba_planner_v2_bench_shader, ba_planner_v2_prod as ba_planner_v2_prod_shader, + ba_planner_v2_mwg_local as ba_planner_v2_mwg_local_shader, + ba_planner_v2_mwg_scan as ba_planner_v2_mwg_scan_shader, + ba_planner_v2_mwg_scatter as ba_planner_v2_mwg_scatter_shader, ba_marshal_pairs_prod as ba_marshal_pairs_prod_shader, ba_pair_disjoint_tree_prod as ba_pair_disjoint_tree_prod_shader, ba_scatter_pairs_prod as ba_scatter_pairs_prod_shader, @@ -940,6 +943,63 @@ ${packLines.join('\n')} ); } + /** + * Multi-workgroup v2 planner — Pass 1 of 3: per-tile local scan. + * One workgroup per TILE = workgroup_size * per_thread buckets. + * Together with gen_ba_planner_v2_mwg_scan_shader and + * gen_ba_planner_v2_mwg_scatter_shader replaces the single-WG + * gen_ba_planner_v2_prod_shader for B > a single workgroup tile. + */ + public gen_ba_planner_v2_mwg_local_shader(workgroup_size: number, per_thread: number): string { + if (workgroup_size <= 0 || per_thread <= 0 || + !Number.isInteger(workgroup_size) || !Number.isInteger(per_thread)) { + throw new Error(`gen_ba_planner_v2_mwg_local_shader: positive integer args required`); + } + return mustache.render( + ba_planner_v2_mwg_local_shader, + { workgroup_size, per_thread, recompile: this.recompile }, + { structs }, + ); + } + + /** + * Multi-workgroup v2 planner — Pass 2 of 3: cross-tile scan, totals, + * dispatch_args, pad-fill of the partial last chunk. Single small + * workgroup. per_thread must satisfy workgroup_size * per_thread >= + * the maximum supported num_wgs (= ceil(max_B / tile_size_pass1)). + * wgi must match the downstream prod kernels' workgroup size. + */ + public gen_ba_planner_v2_mwg_scan_shader(workgroup_size: number, per_thread: number, s: number, wgi: number): string { + if (workgroup_size <= 0 || per_thread <= 0 || s <= 0 || wgi <= 0 || + !Number.isInteger(workgroup_size) || !Number.isInteger(per_thread) || !Number.isInteger(s) || !Number.isInteger(wgi)) { + throw new Error(`gen_ba_planner_v2_mwg_scan_shader: positive integer args required`); + } + return mustache.render( + ba_planner_v2_mwg_scan_shader, + { workgroup_size, per_thread, s, wgi, recompile: this.recompile }, + { structs }, + ); + } + + /** + * Multi-workgroup v2 planner — Pass 3 of 3: per-tile scatter. Launch + * shape mirrors pass 1 (one workgroup per TILE buckets). The + * pair_cap, s, workgroup_size, and per_thread values must match + * those used to compile pass 1 / pass 2 / the downstream prod + * kernels' chunk-size S. + */ + public gen_ba_planner_v2_mwg_scatter_shader(workgroup_size: number, per_thread: number, s: number, pair_cap: number = 64): string { + if (workgroup_size <= 0 || per_thread <= 0 || s <= 0 || pair_cap <= 0 || + !Number.isInteger(workgroup_size) || !Number.isInteger(per_thread) || !Number.isInteger(s) || !Number.isInteger(pair_cap)) { + throw new Error(`gen_ba_planner_v2_mwg_scatter_shader: positive integer args required`); + } + return mustache.render( + ba_planner_v2_mwg_scatter_shader, + { workgroup_size, per_thread, pair_cap, s, recompile: this.recompile }, + { structs }, + ); + } + /** * Marshal pairs — prod variant. Reads num_chunks from * totals[3] (storage), dispatched indirectly off totals[4..6]. diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts b/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts index 51ae0e363a34..a8d059b5ffea 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts @@ -132,7 +132,9 @@ async function compilePipeline( interface Pipelines { csrMeta: GPUComputePipeline; csrActive: GPUComputePipeline; - planner: GPUComputePipeline; + plannerLocal: GPUComputePipeline; + plannerScan: GPUComputePipeline; + plannerScatter: GPUComputePipeline; marshal: GPUComputePipeline; disjoint: GPUComputePipeline; scatter: GPUComputePipeline; @@ -141,7 +143,9 @@ interface Pipelines { layouts: { meta: GPUBindGroupLayout; active: GPUBindGroupLayout; - planner: GPUBindGroupLayout; + plannerLocal: GPUBindGroupLayout; + plannerScan: GPUBindGroupLayout; + plannerScatter: GPUBindGroupLayout; marshal: GPUBindGroupLayout; disjoint: GPUBindGroupLayout; scatter: GPUBindGroupLayout; @@ -155,8 +159,10 @@ async function compileAll( sm: ShaderManager, wgi: number, s: number, - tpb: number, - per_thread: number, + tile_tpb: number, + tile_per_thread: number, + scan_tpb: number, + scan_per_thread: number, ): Promise { const layouts: Pipelines['layouts'] = { meta: device.createBindGroupLayout({ @@ -171,17 +177,39 @@ async function compileAll( uniformEntry(4), ], }), - planner: device.createBindGroupLayout({ + plannerLocal: device.createBindGroupLayout({ entries: [ roStorageEntry(0), - roStorageEntry(1), + rwStorageEntry(1), rwStorageEntry(2), rwStorageEntry(3), rwStorageEntry(4), rwStorageEntry(5), + uniformEntry(6), + ], + }), + plannerScan: device.createBindGroupLayout({ + entries: [ + rwStorageEntry(0), + rwStorageEntry(1), + rwStorageEntry(2), + rwStorageEntry(3), + uniformEntry(4), + ], + }), + plannerScatter: device.createBindGroupLayout({ + entries: [ + roStorageEntry(0), + roStorageEntry(1), + roStorageEntry(2), + roStorageEntry(3), + roStorageEntry(4), + roStorageEntry(5), rwStorageEntry(6), rwStorageEntry(7), - uniformEntry(8), + rwStorageEntry(8), + rwStorageEntry(9), + uniformEntry(10), ], }), marshal: device.createBindGroupLayout({ @@ -209,14 +237,26 @@ async function compileAll( }), }; - const [csrMeta, csrActive, planner, marshal, disjoint, scatter, carry, v2ToRunning] = await Promise.all([ + const [csrMeta, csrActive, plannerLocal, plannerScan, plannerScatter, marshal, disjoint, scatter, carry, v2ToRunning] = await Promise.all([ compilePipeline(device, layouts.meta, sm.gen_csr_to_v2_meta_shader(wgi), `csr-meta-wg${wgi}`), compilePipeline(device, layouts.active, sm.gen_csr_to_v2_active_sums_shader(wgi), `csr-active-wg${wgi}`), compilePipeline( device, - layouts.planner, - sm.gen_ba_planner_v2_prod_shader(tpb, per_thread, s, wgi, 64), - `planner-v2-prod-T${tpb}-P${per_thread}-S${s}-W${wgi}`, + layouts.plannerLocal, + sm.gen_ba_planner_v2_mwg_local_shader(tile_tpb, tile_per_thread), + `planner-v2-mwg-local-T${tile_tpb}-P${tile_per_thread}`, + ), + compilePipeline( + device, + layouts.plannerScan, + sm.gen_ba_planner_v2_mwg_scan_shader(scan_tpb, scan_per_thread, s, wgi), + `planner-v2-mwg-scan-T${scan_tpb}-P${scan_per_thread}-S${s}-W${wgi}`, + ), + compilePipeline( + device, + layouts.plannerScatter, + sm.gen_ba_planner_v2_mwg_scatter_shader(tile_tpb, tile_per_thread, s, 64), + `planner-v2-mwg-scatter-T${tile_tpb}-P${tile_per_thread}-S${s}`, ), compilePipeline(device, layouts.marshal, sm.gen_ba_marshal_pairs_prod_shader(wgi, s), `marshal-prod-W${wgi}-S${s}`), compilePipeline(device, layouts.disjoint, sm.gen_ba_pair_disjoint_tree_prod_shader(wgi, s), `disjoint-prod-W${wgi}-S${s}`), @@ -224,7 +264,7 @@ async function compileAll( compilePipeline(device, layouts.carry, sm.gen_ba_carry_copy_prod_shader(wgi), `carry-prod-W${wgi}`), compilePipeline(device, layouts.v2Run, sm.gen_v2_to_running_shader(wgi), `v2-to-running-wg${wgi}`), ]); - return { csrMeta, csrActive, planner, marshal, disjoint, scatter, carry, v2ToRunning, layouts }; + return { csrMeta, csrActive, plannerLocal, plannerScan, plannerScatter, marshal, disjoint, scatter, carry, v2ToRunning, layouts }; } interface Scratch { @@ -236,20 +276,26 @@ interface Scratch { countsB: GPUBuffer; offsetsA: GPUBuffer; offsetsB: GPUBuffer; + bucketLocalPairOff: GPUBuffer; + bucketLocalCarryOff: GPUBuffer; + bucketLocalNewOff: GPUBuffer; + wgTotals: GPUBuffer; perLevelChunkPlan: GPUBuffer[]; perLevelScatterPlan: GPUBuffer[]; perLevelCarryPlan: GPUBuffer[]; perLevelTotals: GPUBuffer[]; metaParams: GPUBuffer; activeParams: GPUBuffer; - plannerParams: GPUBuffer; + plannerLocalParams: GPUBuffer; + plannerScanParams: GPUBuffer; + plannerScatterParams: GPUBuffer; marshalConsts: GPUBuffer; scatterConsts: GPUBuffer; carryConsts: GPUBuffer; v2RunParams: GPUBuffer; M: number; maxChunks: number; - perThread: number; + numTiles: number; } function allocScratch( @@ -258,8 +304,7 @@ function allocScratch( input_size: number, s: number, max_levels: number, - tpb: number, - per_thread: number, + tile_size: number, ): Scratch { // M = real slots (input_size) + 3 reserved tail slots: pad_left, // pad_right, discard. Reserved slots aren't touched by the converter @@ -268,6 +313,7 @@ function allocScratch( // chunk of chunk_plan / scatter_plan with these constant indices. const M = input_size + 3; const maxChunks = Math.max(1, Math.ceil(input_size / 2 / s) + 1); + const numTiles = Math.max(1, Math.ceil(num_columns / tile_size)); const mk = (bytes: number, extra: GPUBufferUsageFlags = 0): GPUBuffer => device.createBuffer({ size: bytes, usage: GPUBufferUsage.STORAGE | extra }); @@ -286,6 +332,12 @@ function allocScratch( const offsetsA = mk(offsetsBytes); const offsetsB = mk(offsetsBytes); + const bucketLocalBytes = num_columns * 4; + const bucketLocalPairOff = mk(bucketLocalBytes); + const bucketLocalCarryOff = mk(bucketLocalBytes); + const bucketLocalNewOff = mk(bucketLocalBytes); + const wgTotals = mk(3 * numTiles * 4); + const perLevelChunkPlan: GPUBuffer[] = []; const perLevelScatterPlan: GPUBuffer[] = []; const perLevelCarryPlan: GPUBuffer[] = []; @@ -305,7 +357,9 @@ function allocScratch( device.createBuffer({ size: Math.max(16, bytes), usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); const metaParams = ub(16); const activeParams = ub(16); - const plannerParams = ub(16); + const plannerLocalParams = ub(16); + const plannerScanParams = ub(16); + const plannerScatterParams = ub(16); const marshalConsts = ub(16); const scatterConsts = ub(16); const carryConsts = ub(16); @@ -314,10 +368,12 @@ function allocScratch( return { activeA, activeB, chainBuf, tempOut, countsA, countsB, offsetsA, offsetsB, + bucketLocalPairOff, bucketLocalCarryOff, bucketLocalNewOff, wgTotals, perLevelChunkPlan, perLevelScatterPlan, perLevelCarryPlan, perLevelTotals, - metaParams, activeParams, plannerParams, + metaParams, activeParams, + plannerLocalParams, plannerScanParams, plannerScatterParams, marshalConsts, scatterConsts, carryConsts, v2RunParams, - M, maxChunks, perThread: per_thread, + M, maxChunks, numTiles, }; } @@ -330,13 +386,19 @@ function destroyScratch(scratch: Scratch): void { scratch.countsB.destroy(); scratch.offsetsA.destroy(); scratch.offsetsB.destroy(); + scratch.bucketLocalPairOff.destroy(); + scratch.bucketLocalCarryOff.destroy(); + scratch.bucketLocalNewOff.destroy(); + scratch.wgTotals.destroy(); for (const b of scratch.perLevelChunkPlan) b.destroy(); for (const b of scratch.perLevelScatterPlan) b.destroy(); for (const b of scratch.perLevelCarryPlan) b.destroy(); for (const b of scratch.perLevelTotals) b.destroy(); scratch.metaParams.destroy(); scratch.activeParams.destroy(); - scratch.plannerParams.destroy(); + scratch.plannerLocalParams.destroy(); + scratch.plannerScanParams.destroy(); + scratch.plannerScatterParams.destroy(); scratch.marshalConsts.destroy(); scratch.scatterConsts.destroy(); scratch.carryConsts.destroy(); @@ -359,16 +421,26 @@ export async function runSmvpV2PairTree(opts: SmvpV2PairTreeOptions): Promise= num_columns (${num_columns}).`); + + // Multi-workgroup planner sizing. TILE = tile_tpb * tile_per_thread + // is the bucket count handled by a single workgroup in passes 1 and + // 3. Pass 2 uses a single workgroup and must satisfy + // scan_tpb * scan_per_thread >= ceil(num_columns / TILE). The picked + // values (TILE=1024, scan capacity=1024) cover num_columns up to 2^20. + const tile_tpb = 256; + const tile_per_thread = 4; + const tile_size = tile_tpb * tile_per_thread; + const scan_tpb = 256; + const scan_per_thread = 4; + const numTiles = Math.max(1, Math.ceil(num_columns / tile_size)); + if (scan_tpb * scan_per_thread < numTiles) { + throw new Error(`smvp_v2_pair_tree: scan_tpb*scan_per_thread (${scan_tpb}*${scan_per_thread}=${scan_tpb * scan_per_thread}) must be >= num_tiles (${numTiles}) for num_columns=${num_columns}.`); } - const pipelines = await compileAll(device, shaderManager, wgi, s, tpb, per_thread); - const scratch = allocScratch(device, num_columns, input_size, s, max_levels, tpb, per_thread); + const pipelines = await compileAll(device, shaderManager, wgi, s, tile_tpb, tile_per_thread, scan_tpb, scan_per_thread); + const scratch = allocScratch(device, num_columns, input_size, s, max_levels, tile_size); const M = scratch.M; const padLeft = input_size; @@ -376,7 +448,9 @@ export async function runSmvpV2PairTree(opts: SmvpV2PairTreeOptions): Promise) { } `; +export const ba_planner_v2_mwg_local = `{{> structs }} + +// Multi-workgroup v2 planner — Pass 1 of 3: per-tile local scan. +// +// The single-workgroup ba_planner_v2_prod was limited by TPB * +// PER_THREAD >= B; production B = num_columns >= 32768 blows out +// register/shared limits. Splitting the planner into three passes lets +// each workgroup process a fixed TILE = TPB * PER_THREAD bucket window +// regardless of total B, with cross-tile fixup applied in pass 2. +// +// Per WG (wg_id): +// tile_start = wg_id * TILE +// tile_end = min(tile_start + TILE, B) +// For bucket b in [tile_start, tile_end): +// pc = counts[b] / 2, cf = counts[b] & 1, nc = pc + cf +// bucket_local_pair_off[b] = exclusive scan of pc within the tile +// bucket_local_carry_off[b] = exclusive scan of cf within the tile +// bucket_local_new_off[b] = exclusive scan of nc within the tile +// new_counts[b] = nc +// Last thread writes inclusive sums to: +// wg_totals[3*wg_id + 0] = total pairs in tile +// wg_totals[3*wg_id + 1] = total carries in tile +// wg_totals[3*wg_id + 2] = total new buckets in tile +// +// Pass 2 (ba_planner_v2_mwg_scan) consumes wg_totals to produce per-WG +// global starts; pass 3 (ba_planner_v2_mwg_scatter) reads +// bucket_local_*_off plus the global start to emit the per-bucket +// chunk_plan / scatter_plan / carry_plan entries. + +const TPB: u32 = {{ workgroup_size }}u; +const PER_THREAD: u32 = {{ per_thread }}u; +const TILE: u32 = TPB * PER_THREAD; + +@group(0) @binding(0) var counts: array; +@group(0) @binding(1) var bucket_local_pair_off: array; +@group(0) @binding(2) var bucket_local_carry_off: array; +@group(0) @binding(3) var bucket_local_new_off: array; +@group(0) @binding(4) var wg_totals: array; +@group(0) @binding(5) var new_counts: array; +@group(0) @binding(6) var params: vec4; +// params.x = B (num_columns) + +var pair_scan: array; +var carry_scan: array; +var new_scan: array; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(local_invocation_id) lid: vec3, + @builtin(workgroup_id) wgid: vec3) { + let tid = lid.x; + let wg_id = wgid.x; + let B = params.x; + let tile_start = wg_id * TILE; + + var local_pc: array; + var local_cf: array; + var local_nc: array; + var sum_p: u32 = 0u; + var sum_c: u32 = 0u; + var sum_n: u32 = 0u; + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tile_start + tid * PER_THREAD + k; + var pc: u32 = 0u; + var cf: u32 = 0u; + var nc: u32 = 0u; + if (b < B) { + let n = counts[b]; + pc = n / 2u; + cf = n & 1u; + nc = pc + cf; + } + local_pc[k] = pc; + local_cf[k] = cf; + local_nc[k] = nc; + sum_p += pc; + sum_c += cf; + sum_n += nc; + } + + pair_scan[tid] = sum_p; + carry_scan[tid] = sum_c; + new_scan[tid] = sum_n; + workgroupBarrier(); + for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) { + var add_p: u32 = 0u; + var add_c: u32 = 0u; + var add_n: u32 = 0u; + if (tid >= stride) { + add_p = pair_scan[tid - stride]; + add_c = carry_scan[tid - stride]; + add_n = new_scan[tid - stride]; + } + workgroupBarrier(); + if (tid >= stride) { + pair_scan[tid] = pair_scan[tid] + add_p; + carry_scan[tid] = carry_scan[tid] + add_c; + new_scan[tid] = new_scan[tid] + add_n; + } + workgroupBarrier(); + } + var local_pair_off: u32 = pair_scan[tid] - sum_p; + var local_carry_off: u32 = carry_scan[tid] - sum_c; + var local_new_off: u32 = new_scan[tid] - sum_n; + + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tile_start + tid * PER_THREAD + k; + if (b >= B) { break; } + bucket_local_pair_off[b] = local_pair_off; + bucket_local_carry_off[b] = local_carry_off; + bucket_local_new_off[b] = local_new_off; + new_counts[b] = local_nc[k]; + local_pair_off += local_pc[k]; + local_carry_off += local_cf[k]; + local_new_off += local_nc[k]; + } + + if (tid == TPB - 1u) { + wg_totals[3u * wg_id + 0u] = pair_scan[tid]; + wg_totals[3u * wg_id + 1u] = carry_scan[tid]; + wg_totals[3u * wg_id + 2u] = new_scan[tid]; + } + + {{{ recompile }}} +} +`; + +export const ba_planner_v2_mwg_scan = `{{> structs }} + +// Multi-workgroup v2 planner — Pass 2 of 3: cross-tile scan + totals + +// pad-fill. +// +// Runs as a single small workgroup that scans the per-WG inclusive sums +// emitted by pass 1 (ba_planner_v2_mwg_local) into per-WG exclusive +// global start offsets. Also emits totals[0..9] (grand totals + +// num_chunks + indirect-dispatch triples) and pad-fills the last partial +// chunk of chunk_plan / scatter_plan so the marshal / scatter prod +// kernels never read garbage indices on partial chunks. +// +// Layout of wg_totals: 3 u32 per WG. +// wg_totals[3*wg + 0] = pair count in WG (in) -> pair global start (out) +// wg_totals[3*wg + 1] = carry count in WG (in) -> carry global start (out) +// wg_totals[3*wg + 2] = new count in WG (in) -> new global start (out) +// +// Compile-time: +// TPB : workgroup size +// PER_TH : entries per thread (TPB * PER_TH must be >= num_wgs) +// S : chunk size in pairs +// WGI : downstream kernel workgroup size (must match marshal / +// disjoint / scatter / carry prod kernels) + +const TPB: u32 = {{ workgroup_size }}u; +const PER_TH: u32 = {{ per_thread }}u; +const S: u32 = {{ s }}u; +const WGI: u32 = {{ wgi }}u; + +@group(0) @binding(0) var wg_totals: array; +@group(0) @binding(1) var totals: array; +@group(0) @binding(2) var chunk_plan: array; +@group(0) @binding(3) var scatter_plan: array; +@group(0) @binding(4) var params: vec4; +// params.x = num_wgs +// params.y = pad_left_idx +// params.z = pad_right_idx +// params.w = discard_idx + +var pair_scan: array; +var carry_scan: array; +var new_scan: array; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; + let num_wgs = params.x; + + var local_p: array; + var local_c: array; + var local_n: array; + var sum_p: u32 = 0u; + var sum_c: u32 = 0u; + var sum_n: u32 = 0u; + for (var k: u32 = 0u; k < PER_TH; k = k + 1u) { + let w = tid * PER_TH + k; + var p: u32 = 0u; + var c: u32 = 0u; + var n: u32 = 0u; + if (w < num_wgs) { + p = wg_totals[3u * w + 0u]; + c = wg_totals[3u * w + 1u]; + n = wg_totals[3u * w + 2u]; + } + local_p[k] = p; + local_c[k] = c; + local_n[k] = n; + sum_p += p; + sum_c += c; + sum_n += n; + } + + pair_scan[tid] = sum_p; + carry_scan[tid] = sum_c; + new_scan[tid] = sum_n; + workgroupBarrier(); + for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) { + var add_p: u32 = 0u; + var add_c: u32 = 0u; + var add_n: u32 = 0u; + if (tid >= stride) { + add_p = pair_scan[tid - stride]; + add_c = carry_scan[tid - stride]; + add_n = new_scan[tid - stride]; + } + workgroupBarrier(); + if (tid >= stride) { + pair_scan[tid] = pair_scan[tid] + add_p; + carry_scan[tid] = carry_scan[tid] + add_c; + new_scan[tid] = new_scan[tid] + add_n; + } + workgroupBarrier(); + } + var off_p: u32 = pair_scan[tid] - sum_p; + var off_c: u32 = carry_scan[tid] - sum_c; + var off_n: u32 = new_scan[tid] - sum_n; + for (var k: u32 = 0u; k < PER_TH; k = k + 1u) { + let w = tid * PER_TH + k; + if (w >= num_wgs) { break; } + wg_totals[3u * w + 0u] = off_p; + wg_totals[3u * w + 1u] = off_c; + wg_totals[3u * w + 2u] = off_n; + off_p += local_p[k]; + off_c += local_c[k]; + off_n += local_n[k]; + } + + if (tid == TPB - 1u) { + let tp = pair_scan[tid]; + let tc = carry_scan[tid]; + let tn = new_scan[tid]; + totals[0] = tp; + totals[1] = tc; + totals[2] = tn; + let num_chunks = (tp + S - 1u) / S; + totals[3] = num_chunks; + totals[4] = (num_chunks + WGI - 1u) / WGI; + totals[5] = 1u; + totals[6] = 1u; + totals[7] = (tc + WGI - 1u) / WGI; + totals[8] = 1u; + totals[9] = 1u; + } + + workgroupBarrier(); + if (tid == 0u) { + let tp = pair_scan[TPB - 1u]; + let num_chunks = (tp + S - 1u) / S; + let pad_end = num_chunks * S; + let pad_left = params.y; + let pad_right = params.z; + let discard_idx = params.w; + for (var i: u32 = tp; i < pad_end; i = i + 1u) { + chunk_plan[2u * i + 0u] = pad_left; + chunk_plan[2u * i + 1u] = pad_right; + scatter_plan[i] = discard_idx; + } + } + + {{{ recompile }}} +} +`; + +export const ba_planner_v2_mwg_scatter = `{{> structs }} + +// Multi-workgroup v2 planner — Pass 3 of 3: per-tile scatter. +// +// One workgroup per TILE buckets (same launch shape as pass 1). Reads +// the per-bucket local prefix offsets from pass 1 plus the per-WG +// exclusive global starts from pass 2 to compute global pair / carry / +// new offsets per bucket, then writes: +// chunk_plan — pair-major operand indices into active_sums +// scatter_plan — per-pair destination index into next-level +// active_sums +// carry_plan — odd-count bucket carry-forward (src, dst) pairs +// new_offsets — per-bucket offset in next-level active_sums +// +// Compile-time: +// TPB : workgroup size (must match pass 1) +// PER_THREAD : buckets per thread (must match pass 1) +// PAIR_CAP : per-bucket pair-count bound (matches the single-WG +// planner — guards the inner emit loop so the WGSL +// compiler can const-bound it; pc is enforced separately) +// S : chunk size in pairs + +const TPB: u32 = {{ workgroup_size }}u; +const PER_THREAD: u32 = {{ per_thread }}u; +const PAIR_CAP: u32 = {{ pair_cap }}u; +const S: u32 = {{ s }}u; +const TILE: u32 = TPB * PER_THREAD; + +@group(0) @binding(0) var counts: array; +@group(0) @binding(1) var offsets: array; +@group(0) @binding(2) var bucket_local_pair_off: array; +@group(0) @binding(3) var bucket_local_carry_off: array; +@group(0) @binding(4) var bucket_local_new_off: array; +@group(0) @binding(5) var wg_totals: array; +@group(0) @binding(6) var chunk_plan: array; +@group(0) @binding(7) var scatter_plan: array; +@group(0) @binding(8) var carry_plan: array; +@group(0) @binding(9) var new_offsets: array; +@group(0) @binding(10) var params: vec4; +// params.x = B (num_columns) + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(local_invocation_id) lid: vec3, + @builtin(workgroup_id) wgid: vec3) { + let tid = lid.x; + let wg_id = wgid.x; + let B = params.x; + let tile_start = wg_id * TILE; + + let wg_global_pair_start = wg_totals[3u * wg_id + 0u]; + let wg_global_carry_start = wg_totals[3u * wg_id + 1u]; + let wg_global_new_start = wg_totals[3u * wg_id + 2u]; + + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tile_start + tid * PER_THREAD + k; + if (b >= B) { break; } + + let n = counts[b]; + let pc = n / 2u; + let cf = n & 1u; + let bucket_base = offsets[b]; + + let global_pair_off = wg_global_pair_start + bucket_local_pair_off[b]; + let global_carry_off = wg_global_carry_start + bucket_local_carry_off[b]; + let global_new_off = wg_global_new_start + bucket_local_new_off[b]; + + new_offsets[b] = global_new_off; + + for (var j: u32 = 0u; j < PAIR_CAP; j = j + 1u) { + if (j >= pc) { break; } + let global_slot = global_pair_off + j; + let chunk_id = global_slot / S; + let slot_in_chunk = global_slot % S; + let cp_base = 2u * (chunk_id * S + slot_in_chunk); + chunk_plan[cp_base + 0u] = bucket_base + 2u * j; + chunk_plan[cp_base + 1u] = bucket_base + 2u * j + 1u; + scatter_plan[chunk_id * S + slot_in_chunk] = global_new_off + j; + } + + if (cf != 0u) { + carry_plan[2u * global_carry_off + 0u] = bucket_base + n - 1u; + carry_plan[2u * global_carry_off + 1u] = global_new_off + pc; + } + } + + {{{ recompile }}} +} +`; + export const ba_planner_v2_prod = `{{> structs }} // Production GPU bin-packing planner for the v2 pair-tree integration. diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_local.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_local.template.wgsl new file mode 100644 index 000000000000..e7b9af0b0a4e --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_local.template.wgsl @@ -0,0 +1,124 @@ +{{> structs }} + +// Multi-workgroup v2 planner — Pass 1 of 3: per-tile local scan. +// +// The single-workgroup ba_planner_v2_prod was limited by TPB * +// PER_THREAD >= B; production B = num_columns >= 32768 blows out +// register/shared limits. Splitting the planner into three passes lets +// each workgroup process a fixed TILE = TPB * PER_THREAD bucket window +// regardless of total B, with cross-tile fixup applied in pass 2. +// +// Per WG (wg_id): +// tile_start = wg_id * TILE +// tile_end = min(tile_start + TILE, B) +// For bucket b in [tile_start, tile_end): +// pc = counts[b] / 2, cf = counts[b] & 1, nc = pc + cf +// bucket_local_pair_off[b] = exclusive scan of pc within the tile +// bucket_local_carry_off[b] = exclusive scan of cf within the tile +// bucket_local_new_off[b] = exclusive scan of nc within the tile +// new_counts[b] = nc +// Last thread writes inclusive sums to: +// wg_totals[3*wg_id + 0] = total pairs in tile +// wg_totals[3*wg_id + 1] = total carries in tile +// wg_totals[3*wg_id + 2] = total new buckets in tile +// +// Pass 2 (ba_planner_v2_mwg_scan) consumes wg_totals to produce per-WG +// global starts; pass 3 (ba_planner_v2_mwg_scatter) reads +// bucket_local_*_off plus the global start to emit the per-bucket +// chunk_plan / scatter_plan / carry_plan entries. + +const TPB: u32 = {{ workgroup_size }}u; +const PER_THREAD: u32 = {{ per_thread }}u; +const TILE: u32 = TPB * PER_THREAD; + +@group(0) @binding(0) var counts: array; +@group(0) @binding(1) var bucket_local_pair_off: array; +@group(0) @binding(2) var bucket_local_carry_off: array; +@group(0) @binding(3) var bucket_local_new_off: array; +@group(0) @binding(4) var wg_totals: array; +@group(0) @binding(5) var new_counts: array; +@group(0) @binding(6) var params: vec4; +// params.x = B (num_columns) + +var pair_scan: array; +var carry_scan: array; +var new_scan: array; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(local_invocation_id) lid: vec3, + @builtin(workgroup_id) wgid: vec3) { + let tid = lid.x; + let wg_id = wgid.x; + let B = params.x; + let tile_start = wg_id * TILE; + + var local_pc: array; + var local_cf: array; + var local_nc: array; + var sum_p: u32 = 0u; + var sum_c: u32 = 0u; + var sum_n: u32 = 0u; + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tile_start + tid * PER_THREAD + k; + var pc: u32 = 0u; + var cf: u32 = 0u; + var nc: u32 = 0u; + if (b < B) { + let n = counts[b]; + pc = n / 2u; + cf = n & 1u; + nc = pc + cf; + } + local_pc[k] = pc; + local_cf[k] = cf; + local_nc[k] = nc; + sum_p += pc; + sum_c += cf; + sum_n += nc; + } + + pair_scan[tid] = sum_p; + carry_scan[tid] = sum_c; + new_scan[tid] = sum_n; + workgroupBarrier(); + for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) { + var add_p: u32 = 0u; + var add_c: u32 = 0u; + var add_n: u32 = 0u; + if (tid >= stride) { + add_p = pair_scan[tid - stride]; + add_c = carry_scan[tid - stride]; + add_n = new_scan[tid - stride]; + } + workgroupBarrier(); + if (tid >= stride) { + pair_scan[tid] = pair_scan[tid] + add_p; + carry_scan[tid] = carry_scan[tid] + add_c; + new_scan[tid] = new_scan[tid] + add_n; + } + workgroupBarrier(); + } + var local_pair_off: u32 = pair_scan[tid] - sum_p; + var local_carry_off: u32 = carry_scan[tid] - sum_c; + var local_new_off: u32 = new_scan[tid] - sum_n; + + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tile_start + tid * PER_THREAD + k; + if (b >= B) { break; } + bucket_local_pair_off[b] = local_pair_off; + bucket_local_carry_off[b] = local_carry_off; + bucket_local_new_off[b] = local_new_off; + new_counts[b] = local_nc[k]; + local_pair_off += local_pc[k]; + local_carry_off += local_cf[k]; + local_new_off += local_nc[k]; + } + + if (tid == TPB - 1u) { + wg_totals[3u * wg_id + 0u] = pair_scan[tid]; + wg_totals[3u * wg_id + 1u] = carry_scan[tid]; + wg_totals[3u * wg_id + 2u] = new_scan[tid]; + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_scan.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_scan.template.wgsl new file mode 100644 index 000000000000..51d05b13384f --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_scan.template.wgsl @@ -0,0 +1,141 @@ +{{> structs }} + +// Multi-workgroup v2 planner — Pass 2 of 3: cross-tile scan + totals + +// pad-fill. +// +// Runs as a single small workgroup that scans the per-WG inclusive sums +// emitted by pass 1 (ba_planner_v2_mwg_local) into per-WG exclusive +// global start offsets. Also emits totals[0..9] (grand totals + +// num_chunks + indirect-dispatch triples) and pad-fills the last partial +// chunk of chunk_plan / scatter_plan so the marshal / scatter prod +// kernels never read garbage indices on partial chunks. +// +// Layout of wg_totals: 3 u32 per WG. +// wg_totals[3*wg + 0] = pair count in WG (in) -> pair global start (out) +// wg_totals[3*wg + 1] = carry count in WG (in) -> carry global start (out) +// wg_totals[3*wg + 2] = new count in WG (in) -> new global start (out) +// +// Compile-time: +// TPB : workgroup size +// PER_TH : entries per thread (TPB * PER_TH must be >= num_wgs) +// S : chunk size in pairs +// WGI : downstream kernel workgroup size (must match marshal / +// disjoint / scatter / carry prod kernels) + +const TPB: u32 = {{ workgroup_size }}u; +const PER_TH: u32 = {{ per_thread }}u; +const S: u32 = {{ s }}u; +const WGI: u32 = {{ wgi }}u; + +@group(0) @binding(0) var wg_totals: array; +@group(0) @binding(1) var totals: array; +@group(0) @binding(2) var chunk_plan: array; +@group(0) @binding(3) var scatter_plan: array; +@group(0) @binding(4) var params: vec4; +// params.x = num_wgs +// params.y = pad_left_idx +// params.z = pad_right_idx +// params.w = discard_idx + +var pair_scan: array; +var carry_scan: array; +var new_scan: array; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; + let num_wgs = params.x; + + var local_p: array; + var local_c: array; + var local_n: array; + var sum_p: u32 = 0u; + var sum_c: u32 = 0u; + var sum_n: u32 = 0u; + for (var k: u32 = 0u; k < PER_TH; k = k + 1u) { + let w = tid * PER_TH + k; + var p: u32 = 0u; + var c: u32 = 0u; + var n: u32 = 0u; + if (w < num_wgs) { + p = wg_totals[3u * w + 0u]; + c = wg_totals[3u * w + 1u]; + n = wg_totals[3u * w + 2u]; + } + local_p[k] = p; + local_c[k] = c; + local_n[k] = n; + sum_p += p; + sum_c += c; + sum_n += n; + } + + pair_scan[tid] = sum_p; + carry_scan[tid] = sum_c; + new_scan[tid] = sum_n; + workgroupBarrier(); + for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) { + var add_p: u32 = 0u; + var add_c: u32 = 0u; + var add_n: u32 = 0u; + if (tid >= stride) { + add_p = pair_scan[tid - stride]; + add_c = carry_scan[tid - stride]; + add_n = new_scan[tid - stride]; + } + workgroupBarrier(); + if (tid >= stride) { + pair_scan[tid] = pair_scan[tid] + add_p; + carry_scan[tid] = carry_scan[tid] + add_c; + new_scan[tid] = new_scan[tid] + add_n; + } + workgroupBarrier(); + } + var off_p: u32 = pair_scan[tid] - sum_p; + var off_c: u32 = carry_scan[tid] - sum_c; + var off_n: u32 = new_scan[tid] - sum_n; + for (var k: u32 = 0u; k < PER_TH; k = k + 1u) { + let w = tid * PER_TH + k; + if (w >= num_wgs) { break; } + wg_totals[3u * w + 0u] = off_p; + wg_totals[3u * w + 1u] = off_c; + wg_totals[3u * w + 2u] = off_n; + off_p += local_p[k]; + off_c += local_c[k]; + off_n += local_n[k]; + } + + if (tid == TPB - 1u) { + let tp = pair_scan[tid]; + let tc = carry_scan[tid]; + let tn = new_scan[tid]; + totals[0] = tp; + totals[1] = tc; + totals[2] = tn; + let num_chunks = (tp + S - 1u) / S; + totals[3] = num_chunks; + totals[4] = (num_chunks + WGI - 1u) / WGI; + totals[5] = 1u; + totals[6] = 1u; + totals[7] = (tc + WGI - 1u) / WGI; + totals[8] = 1u; + totals[9] = 1u; + } + + workgroupBarrier(); + if (tid == 0u) { + let tp = pair_scan[TPB - 1u]; + let num_chunks = (tp + S - 1u) / S; + let pad_end = num_chunks * S; + let pad_left = params.y; + let pad_right = params.z; + let discard_idx = params.w; + for (var i: u32 = tp; i < pad_end; i = i + 1u) { + chunk_plan[2u * i + 0u] = pad_left; + chunk_plan[2u * i + 1u] = pad_right; + scatter_plan[i] = discard_idx; + } + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_scatter.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_scatter.template.wgsl new file mode 100644 index 000000000000..b86c261f2113 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_scatter.template.wgsl @@ -0,0 +1,87 @@ +{{> structs }} + +// Multi-workgroup v2 planner — Pass 3 of 3: per-tile scatter. +// +// One workgroup per TILE buckets (same launch shape as pass 1). Reads +// the per-bucket local prefix offsets from pass 1 plus the per-WG +// exclusive global starts from pass 2 to compute global pair / carry / +// new offsets per bucket, then writes: +// chunk_plan — pair-major operand indices into active_sums +// scatter_plan — per-pair destination index into next-level +// active_sums +// carry_plan — odd-count bucket carry-forward (src, dst) pairs +// new_offsets — per-bucket offset in next-level active_sums +// +// Compile-time: +// TPB : workgroup size (must match pass 1) +// PER_THREAD : buckets per thread (must match pass 1) +// PAIR_CAP : per-bucket pair-count bound (matches the single-WG +// planner — guards the inner emit loop so the WGSL +// compiler can const-bound it; pc is enforced separately) +// S : chunk size in pairs + +const TPB: u32 = {{ workgroup_size }}u; +const PER_THREAD: u32 = {{ per_thread }}u; +const PAIR_CAP: u32 = {{ pair_cap }}u; +const S: u32 = {{ s }}u; +const TILE: u32 = TPB * PER_THREAD; + +@group(0) @binding(0) var counts: array; +@group(0) @binding(1) var offsets: array; +@group(0) @binding(2) var bucket_local_pair_off: array; +@group(0) @binding(3) var bucket_local_carry_off: array; +@group(0) @binding(4) var bucket_local_new_off: array; +@group(0) @binding(5) var wg_totals: array; +@group(0) @binding(6) var chunk_plan: array; +@group(0) @binding(7) var scatter_plan: array; +@group(0) @binding(8) var carry_plan: array; +@group(0) @binding(9) var new_offsets: array; +@group(0) @binding(10) var params: vec4; +// params.x = B (num_columns) + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(local_invocation_id) lid: vec3, + @builtin(workgroup_id) wgid: vec3) { + let tid = lid.x; + let wg_id = wgid.x; + let B = params.x; + let tile_start = wg_id * TILE; + + let wg_global_pair_start = wg_totals[3u * wg_id + 0u]; + let wg_global_carry_start = wg_totals[3u * wg_id + 1u]; + let wg_global_new_start = wg_totals[3u * wg_id + 2u]; + + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tile_start + tid * PER_THREAD + k; + if (b >= B) { break; } + + let n = counts[b]; + let pc = n / 2u; + let cf = n & 1u; + let bucket_base = offsets[b]; + + let global_pair_off = wg_global_pair_start + bucket_local_pair_off[b]; + let global_carry_off = wg_global_carry_start + bucket_local_carry_off[b]; + let global_new_off = wg_global_new_start + bucket_local_new_off[b]; + + new_offsets[b] = global_new_off; + + for (var j: u32 = 0u; j < PAIR_CAP; j = j + 1u) { + if (j >= pc) { break; } + let global_slot = global_pair_off + j; + let chunk_id = global_slot / S; + let slot_in_chunk = global_slot % S; + let cp_base = 2u * (chunk_id * S + slot_in_chunk); + chunk_plan[cp_base + 0u] = bucket_base + 2u * j; + chunk_plan[cp_base + 1u] = bucket_base + 2u * j + 1u; + scatter_plan[chunk_id * S + slot_in_chunk] = global_new_off + j; + } + + if (cf != 0u) { + carry_plan[2u * global_carry_off + 0u] = bucket_base + n - 1u; + carry_plan[2u * global_carry_off + 1u] = global_new_off + pc; + } + } + + {{{ recompile }}} +} From 7b0d12e38257b6470cdb809d39bc412a0be9bc4e Mon Sep 17 00:00:00 2001 From: AztecBot Date: Wed, 20 May 2026 17:35:23 +0000 Subject: [PATCH 32/33] feat(bb/msm): add ?v2=1 path to bench-msm-e2e harness Adds a third A/B path alongside baseline + fused. ?v2=1 enables the use_v2_pair_tree route (multi-WG planner + indirect-dispatch downstream kernels) using packed bases. ?fused=0 disables the fused path so the harness can report baseline vs v2 alone. Results JSON now carries v2_median_ms / v2_samples_ms / v2_speedup / v2_sane for the BrowserStack runner to consume. --- .../ts/dev/msm-webgpu/bench-msm-e2e.ts | 89 +++++++++++++------ 1 file changed, 64 insertions(+), 25 deletions(-) diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-e2e.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-e2e.ts index 95bf2b245bf2..8ac52d31df98 100644 --- a/barretenberg/ts/dev/msm-webgpu/bench-msm-e2e.ts +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-e2e.ts @@ -52,27 +52,30 @@ function median(xs: number[]): number { return s[Math.floor(s.length / 2)]; } +type RunMode = 'baseline' | 'fused' | 'v2'; + async function runPath( ctx: GpuContext, bases: CachedBases, scalars: Uint8Array, - fused: boolean, + mode: RunMode, reps: number, ): Promise<{ ms: number[]; xy: { x: bigint; y: bigint } }> { - const tag = fused ? 'fused' : 'baseline'; - log(`[e2e] ${tag}: warm-up dispatch…`); + const fused = mode === 'fused'; + const useV2 = mode === 'v2'; + log(`[e2e] ${mode}: warm-up dispatch…`); await compute_bn254_msm_batch_affine( - ctx, bases, scalars as unknown as Buffer, false, {}, undefined, 'legacy', false, fused, + ctx, bases, scalars as unknown as Buffer, false, {}, undefined, 'legacy', false, fused, useV2, ); - log(`[e2e] ${tag}: warm-up ok`); + log(`[e2e] ${mode}: warm-up ok`); const ms: number[] = []; let xy = { x: 0n, y: 0n }; for (let r = 0; r < reps; r++) { const t0 = performance.now(); xy = await compute_bn254_msm_batch_affine( - ctx, bases, scalars as unknown as Buffer, false, {}, undefined, 'legacy', false, fused, + ctx, bases, scalars as unknown as Buffer, false, {}, undefined, 'legacy', false, fused, useV2, ); - log(`[e2e] ${tag}: rep ${r} = ${(performance.now() - t0).toFixed(1)} ms`); + log(`[e2e] ${mode}: rep ${r} = ${(performance.now() - t0).toFixed(1)} ms`); ms.push(performance.now() - t0); } return { ms, xy }; @@ -192,35 +195,66 @@ async function main() { if (qp.get('dump') === '1') (globalThis as unknown as { __msm_dump?: boolean }).__msm_dump = true; + // ?v2=1 enables the third path (use_v2_pair_tree). ?fused=0 disables + // the legacy packed-fused path so the harness reports just baseline + // vs v2 (useful for sizing N when fused alone isn't the comparand). + const runV2 = qp.get('v2') === '1'; + const runFused = qp.get('fused') !== '0'; + log('[e2e] running BASELINE (fused_revcarry=false, BigInt bases)…'); rc.postProgress({ kind: 'phase', phase: 'baseline_start' }); - const base = await runPath(ctx, basesBI, scalars, false, reps); + const base = await runPath(ctx, basesBI, scalars, 'baseline', reps); const baseMed = median(base.ms); log(`[e2e] baseline median ${baseMed.toFixed(2)} ms samples=[${base.ms.map(x => x.toFixed(1)).join(',')}]`); rc.postProgress({ kind: 'phase', phase: 'baseline_done', median_ms: baseMed }); - log('[e2e] running FUSED (fused_revcarry=true, packed bases)…'); - rc.postProgress({ kind: 'phase', phase: 'fused_start' }); - const fused = await runPath(ctx, basesPK, scalars, true, reps); - const fusedMed = median(fused.ms); - log(`[e2e] fused median ${fusedMed.toFixed(2)} ms samples=[${fused.ms.map(x => x.toFixed(1)).join(',')}]`); + let fused: { ms: number[]; xy: { x: bigint; y: bigint } } | null = null; + let fusedMed = 0; + if (runFused) { + log('[e2e] running FUSED (fused_revcarry=true, packed bases)…'); + rc.postProgress({ kind: 'phase', phase: 'fused_start' }); + fused = await runPath(ctx, basesPK, scalars, 'fused', reps); + fusedMed = median(fused.ms); + log(`[e2e] fused median ${fusedMed.toFixed(2)} ms samples=[${fused.ms.map(x => x.toFixed(1)).join(',')}]`); + } + + let v2: { ms: number[]; xy: { x: bigint; y: bigint } } | null = null; + let v2Med = 0; + if (runV2) { + log('[e2e] running V2 (use_v2_pair_tree=true, packed bases)…'); + rc.postProgress({ kind: 'phase', phase: 'v2_start' }); + v2 = await runPath(ctx, basesPK, scalars, 'v2', reps); + v2Med = median(v2.ms); + log(`[e2e] v2 median ${v2Med.toFixed(2)} ms samples=[${v2.ms.map(x => x.toFixed(1)).join(',')}]`); + rc.postProgress({ kind: 'phase', phase: 'v2_done', median_ms: v2Med }); + } - const sane = base.xy.x === fused.xy.x && base.xy.y === fused.xy.y; - log(`[e2e] correctness (baseline==fused): ${sane}`); + const fusedSane = fused !== null && base.xy.x === fused.xy.x && base.xy.y === fused.xy.y; + const v2Sane = v2 !== null && base.xy.x === v2.xy.x && base.xy.y === v2.xy.y; + if (fused) log(`[e2e] correctness (baseline==fused): ${fusedSane}`); + if (v2) log(`[e2e] correctness (baseline==v2): ${v2Sane}`); if (refAff) { const baseOk = base.xy.x === refAff.x && base.xy.y === refAff.y; - const fusedOk = fused.xy.x === refAff.x && fused.xy.y === refAff.y; - log(`[e2e] vs CPU oracle: baseline=${baseOk} fused=${fusedOk}`); + log(`[e2e] vs CPU oracle: baseline=${baseOk}`); + if (fused) log(`[e2e] vs CPU oracle: fused=${fused.xy.x === refAff.x && fused.xy.y === refAff.y}`); + if (v2) log(`[e2e] vs CPU oracle: v2=${v2.xy.x === refAff.x && v2.xy.y === refAff.y}`); log(`[e2e] oracle.xy = (${refAff.x}, ${refAff.y})`); } log(`[e2e] baseline.xy = (${base.xy.x}, ${base.xy.y})`); - log(`[e2e] fused.xy = (${fused.xy.x}, ${fused.xy.y})`); - const speedup = baseMed / fusedMed; - log(`[e2e] RESULT: baseline ${baseMed.toFixed(2)} ms fused ${fusedMed.toFixed(2)} ms speedup ${speedup.toFixed(3)}x sane=${sane}`); + if (fused) log(`[e2e] fused.xy = (${fused.xy.x}, ${fused.xy.y})`); + if (v2) log(`[e2e] v2.xy = (${v2.xy.x}, ${v2.xy.y})`); + const fusedSpeedup = fused ? baseMed / fusedMed : 0; + const v2Speedup = v2 ? baseMed / v2Med : 0; + log( + `[e2e] RESULT: baseline ${baseMed.toFixed(2)} ms` + + (fused ? ` fused ${fusedMed.toFixed(2)} ms (x${fusedSpeedup.toFixed(3)})` : '') + + (v2 ? ` v2 ${v2Med.toFixed(2)} ms (x${v2Speedup.toFixed(3)})` : '') + + ` fusedSane=${fusedSane} v2Sane=${v2Sane}`, + ); await rc.postResults({ state: 'done', - params: { logN, n, reps }, + params: { logN, n, reps, fused: runFused, v2: runV2 }, results: [ { name: 'msm_e2e', @@ -229,12 +263,17 @@ async function main() { reps, baseline_median_ms: baseMed, fused_median_ms: fusedMed, + v2_median_ms: v2Med, baseline_samples_ms: base.ms, - fused_samples_ms: fused.ms, - speedup, - sanity_ok: sane, + fused_samples_ms: fused?.ms ?? [], + v2_samples_ms: v2?.ms ?? [], + fused_speedup: fusedSpeedup, + v2_speedup: v2Speedup, + fused_sane: fusedSane, + v2_sane: v2Sane, baseline_xy: { x: base.xy.x.toString(), y: base.xy.y.toString() }, - fused_xy: { x: fused.xy.x.toString(), y: fused.xy.y.toString() }, + fused_xy: fused ? { x: fused.xy.x.toString(), y: fused.xy.y.toString() } : null, + v2_xy: v2 ? { x: v2.xy.x.toString(), y: v2.xy.y.toString() } : null, }, ], }); From e246a9f77a2dc9996d7635f1213d0c3302374c80 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Wed, 20 May 2026 18:22:33 +0000 Subject: [PATCH 33/33] fix(bb/msm): drop unused {{> structs }} partial from mwg planner kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The three multi-WG planner passes don't reference BigInt — they only read/write u32 arrays — but pulled in the structs partial which expanded to a malformed BigInt declaration (empty num_words) when the gen methods didn't pass num_words. Removing the partial fixes M2 WGSL compilation. Also threads logn/v2/fused/cpuoracle/oracle1 query params through the BS runner so bench-msm-e2e can be driven from run-browserstack.mjs. --- .../ts/dev/msm-webgpu/scripts/run-browserstack.mjs | 11 +++++++++++ .../ts/src/msm_webgpu/wgsl/_generated/shaders.ts | 12 +++--------- .../wgsl/cuzk/ba_planner_v2_mwg_local.template.wgsl | 2 -- .../wgsl/cuzk/ba_planner_v2_mwg_scan.template.wgsl | 2 -- .../cuzk/ba_planner_v2_mwg_scatter.template.wgsl | 2 -- 5 files changed, 14 insertions(+), 15 deletions(-) diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs index 4bb7c19a9751..b4201d34a1b9 100644 --- a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs +++ b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs @@ -50,6 +50,11 @@ const { values: argv } = parseArgs({ buckets: { type: "string" }, seed: { type: "string" }, skew: { type: "string" }, + logn: { type: "string" }, + v2: { type: "string" }, + fused: { type: "string" }, + cpuoracle: { type: "string" }, + oracle1: { type: "string" }, port: { type: "string", default: "5198" }, "first-progress-ms": { type: "string" }, "stall-ms": { type: "string", default: "180000" }, @@ -139,6 +144,7 @@ const pageMap = { "bench-msm-oracle": "/dev/msm-webgpu/bench-msm-oracle.html", "bench-msm-oracle-prod": "/dev/msm-webgpu/bench-msm-oracle-prod.html", "bench-smvp-tree": "/dev/msm-webgpu/bench-smvp-tree.html", + "bench-msm-e2e": "/dev/msm-webgpu/bench-msm-e2e.html", sanity: "/dev/msm-webgpu/index.html", }; if (!pageMap[argv.page]) { @@ -428,6 +434,11 @@ async function main() { if (argv.buckets) qp.set("buckets", String(argv.buckets)); if (argv.seed) qp.set("seed", String(argv.seed)); if (argv.skew) qp.set("skew", String(argv.skew)); + if (argv.logn) qp.set("logn", String(argv.logn)); + if (argv.v2) qp.set("v2", String(argv.v2)); + if (argv.fused) qp.set("fused", String(argv.fused)); + if (argv.cpuoracle) qp.set("cpuoracle", String(argv.cpuoracle)); + if (argv.oracle1) qp.set("oracle1", String(argv.oracle1)); const pageUrl = `${baseUrl}${pageMap[argv.page]}?${qp.toString()}`; err(`page URL: ${pageUrl}`); diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index 674d4b0b9780..dd9cf1647044 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -2618,9 +2618,7 @@ fn main(@builtin(local_invocation_id) lid: vec3) { } `; -export const ba_planner_v2_mwg_local = `{{> structs }} - -// Multi-workgroup v2 planner — Pass 1 of 3: per-tile local scan. +export const ba_planner_v2_mwg_local = `// Multi-workgroup v2 planner — Pass 1 of 3: per-tile local scan. // // The single-workgroup ba_planner_v2_prod was limited by TPB * // PER_THREAD >= B; production B = num_columns >= 32768 blows out @@ -2744,9 +2742,7 @@ fn main(@builtin(local_invocation_id) lid: vec3, } `; -export const ba_planner_v2_mwg_scan = `{{> structs }} - -// Multi-workgroup v2 planner — Pass 2 of 3: cross-tile scan + totals + +export const ba_planner_v2_mwg_scan = `// Multi-workgroup v2 planner — Pass 2 of 3: cross-tile scan + totals + // pad-fill. // // Runs as a single small workgroup that scans the per-WG inclusive sums @@ -2887,9 +2883,7 @@ fn main(@builtin(local_invocation_id) lid: vec3) { } `; -export const ba_planner_v2_mwg_scatter = `{{> structs }} - -// Multi-workgroup v2 planner — Pass 3 of 3: per-tile scatter. +export const ba_planner_v2_mwg_scatter = `// Multi-workgroup v2 planner — Pass 3 of 3: per-tile scatter. // // One workgroup per TILE buckets (same launch shape as pass 1). Reads // the per-bucket local prefix offsets from pass 1 plus the per-WG diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_local.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_local.template.wgsl index e7b9af0b0a4e..da814011b5e7 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_local.template.wgsl +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_local.template.wgsl @@ -1,5 +1,3 @@ -{{> structs }} - // Multi-workgroup v2 planner — Pass 1 of 3: per-tile local scan. // // The single-workgroup ba_planner_v2_prod was limited by TPB * diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_scan.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_scan.template.wgsl index 51d05b13384f..8d899ef29579 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_scan.template.wgsl +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_scan.template.wgsl @@ -1,5 +1,3 @@ -{{> structs }} - // Multi-workgroup v2 planner — Pass 2 of 3: cross-tile scan + totals + // pad-fill. // diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_scatter.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_scatter.template.wgsl index b86c261f2113..7dbcea8b24ed 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_scatter.template.wgsl +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_mwg_scatter.template.wgsl @@ -1,5 +1,3 @@ -{{> structs }} - // Multi-workgroup v2 planner — Pass 3 of 3: per-tile scatter. // // One workgroup per TILE buckets (same launch shape as pass 1). Reads