Skip to content

feat(bb/msm): workgroup-scan fused round kernel + PackedField primitive layer (v2 foundation)#23385

Draft
AztecBot wants to merge 11 commits into
cb/6a4c5cf5ec82from
cb/80a1272a3f32
Draft

feat(bb/msm): workgroup-scan fused round kernel + PackedField primitive layer (v2 foundation)#23385
AztecBot wants to merge 11 commits into
cb/6a4c5cf5ec82from
cb/80a1272a3f32

Conversation

@AztecBot
Copy link
Copy Markdown
Collaborator

@AztecBot AztecBot commented May 19, 2026

Why

Foundation for the msm_webgpu_v2 rewrite described in the plan gist. Two intertwined problems with the current fused_revcarry path on the parent branch:

  1. Wrong algorithm. batch_affine_fused_revcarry.template.wgsl slices the per-subtask pair pool into SCHUNK=16-pair per-thread chunks with one fr_inv_by_a per thread. The validated bench (bench_batch_affine.template.wgsl, 22 ns/pair on M2) uses a workgroup-level Hillis-Steele scan with one fr_inv_by_a per workgroup.
  2. Packed storage bundled into the same flag. Every shader is rendered through {{#packed}}/{{^packed}} mustache pairs; the last 3 commits on this branch are layout-mismatch bugs from that approach. The v2 plan switches to packed-only with no fallback — primitives are the only place pack/unpack appears.

What's in this PR

Foundation only — type system + one new kernel + a standalone correctness-gated microbench. Does NOT yet replace the round loop, port BPR / horner / finalize, or wire a v2 orchestrator.

  • wgsl/struct/packed_field.template.wgsl — defines struct PackedField { lo: vec4<u32>, hi: vec4<u32> } and the wrappers mont_p, fr_add_p, fr_sub_p, fr_neg_p, fr_inv_p, plus field_load_ro/_rw, field_store, is_zero_packed, eq_packed, get_p_packed, get_r_packed, get_zero_packed, get_r. Each primitive body is a 3-line unpack-call-pack around the existing BigInt-limb implementation.
  • wgsl/cuzk/batch_affine_fused_wg_scan.template.wgsl — new fused round kernel. TPB threads cooperating on BATCH_SIZE = TPB*BS pairs per workgroup with one fr_inv_by_a per workgroup. Direct port of the bench_batch_affine design with bucket-indirect loads/stores via pair_target_meta. Every field-element variable is PackedField.
  • cuzk/shader_manager.ts — adds gen_batch_affine_fused_wg_scan_shader(tpb, bs).
  • dev/msm-webgpu/bench-fused-wg-scan.{ts,html} — standalone bench harness with full noble-curves correctness oracle (on-curve BN254 G1 pairs; GPU output decoded from packed Mont form, compared bit-exact to P.add(Q).toAffine()).
  • dev/msm-webgpu/scripts/run-browserstack.mjs — adds bench-fused-wg-scan to the pageMap so the runner can drive it via --page.

Validation on BrowserStack M2 (Chrome 148)

Correctness: every sweep size, every dispatched run, returned correctness="pass" — all 4096 / 65536 R_i = P_i + Q_i pairs match noble's reference bit-exact.

Perf curve at TOTAL=65536 pairs:

batch_size TPB BS num_wgs median_ms ns/pair correctness
256 64 4 256 3.5 53 pass
512 64 8 128 4.0 61 pass
1024 64 16 64 3.6 55 pass
2048 64 32 32 2.9 44 pass

bench_batch_affine's reference number on the same hardware is ~22 ns/pair at B=1024 (BigInt-direct primitives, no bucket indirection). The new kernel sits at 2× the bench's theoretical ceiling at the best batch size — the residual gap is the PackedField wrapper overhead (2 unpacks + 1 pack per primitive call) plus the bucket-indirect pair_target_meta + val_idx lookups per pair.

The new sweet spot is B=2048 (BS=32) rather than the bench's B=1024 — bigger BS amortises both the workgroup scan and the inversion better when the per-pair work is heavier (PackedField wrappers + bucket indirection). On the prior fused_revcarry per-thread design the estimate was ~60-80 ns/pair on the same hardware, so this is a 30-40 % round-kernel improvement on top of design clarity.

Sweep data points at TOTAL=4096 (GPU-underutilised regime, for completeness):

batch_size ns/pair
256 244
512 244
1024 366
2048 537

At small N the workgroup count drops below the GPU's SM count and the scan/inversion amortisation gets washed out by launch overhead — expected, only relevant as a sanity baseline.

What's NOT in this PR

Per the plan gist:

  • msm_webgpu_v2/ directory scaffold
  • Packed-only ports of ba_init, ba_schedule, ba_finalize_*, batch_inverse_parallel, convert_points_only, bpr_bn254, horner_reduce_bn254
  • Host-side batch_affine.ts / msm.ts rewrite
  • v2 e2e bench harness
  • GLV — explicitly out of scope

Three Tint-only fixes during validation

Surfaced by running the bench on BS M2 (Chrome 148) — none caught by the local render-time sanity check:

  1. packed_field.template.wgsl — the comment referenced {{{ dec_unpack }}} as literal mustache syntax. Mustache substituted the rendered unpack256_to_limbs body INTO the comment, so its second line broke out and Tint saw a function body at top level. Reworded the comment to not reference mustache tags.
  2. batch_affine_fused_wg_scancount_buf: array<atomic<u32>> was declared read; atomics in storage must be read_write per WGSL spec. Flipped to read_write; bench TS bind-group layout updated to match.
  3. batch_affine_fused_wg_scan — the early-return if (batch_base >= n) return; was guarded by an atomicLoad result, which Tint's uniformity analysis treats as non-uniform. The subsequent workgroupBarrier was rejected. Restructured to NOT early-return: threads with no work clamp chunk_len=0 and contribute identity (R = Mont 1) to the workgroup scan, skipping the actual load/store loops but staying live through barriers.

Also added fn get_r() to the packed_field partial so fr_pow_funcs' internal reference resolves.

Next steps

  1. (maybe optional) revisit the PackedField wrapper overhead — if 22→44 ns/pair gap on the round kernel is too much for the MSM target, options are: a) inline pack/unpack at function-call sites manually (sacrifices design constraint); b) introduce a "PackedField with cached limbs" handle that holds the unpacked form alongside the packed form, so mont chains don't repack between calls; c) accept the 2× gap as the cost of the design and proceed.
  2. Port the per-stage shaders (BPR / horner / finalize / batch_inverse_parallel-for-finalize / convert_points_only) to PackedField primitives.
  3. v2 MSM orchestrator (msm_webgpu_v2/).
  4. e2e bench + N=2^16..2^20 sweep vs legacy.

References

…ve layer

Foundation for msm_webgpu_v2 rewrite. Two new files; no existing code modified
beyond shader_manager.ts and the auto-regenerated wgsl/_generated/shaders.ts.

- wgsl/struct/packed_field.template.wgsl — new partial defining
  struct PackedField { lo: vec4<u32>, hi: vec4<u32> } and the wrappers
  mont_p, fr_add_p, fr_sub_p, fr_neg_p, fr_inv_p, plus field_load_ro/_rw,
  field_store, is_zero_packed, eq_packed, get_p_packed, get_r_packed,
  get_zero_packed. Each primitive body is a 3-line unpack-call-pack wrapper
  around the existing BigInt-limb implementations. The 20x13-bit BigInt
  representation only ever exists transiently inside these wrappers.

- wgsl/cuzk/batch_affine_fused_wg_scan.template.wgsl — new fused round
  kernel. TPB threads cooperating on BATCH_SIZE = TPB*BS pairs per
  workgroup with one fr_inv_by_a per workgroup. Direct port of the
  bench_batch_affine design (validated at 22 ns/pair on M2) with
  bucket-indirect loads/stores via pair_target_meta. Every field-element
  variable is PackedField. The scheduler-enforced distinct-buckets
  invariant means the phase-D scatter has zero intra-workgroup RAW
  hazards.

- cuzk/shader_manager.ts — adds gen_batch_affine_fused_wg_scan_shader.
  Validates TPB is a power of two for Hillis-Steele.

Not in this PR: msm_webgpu_v2 directory fork, packed ports of BPR /
horner / finalize / convert, host-side dispatch rewrite, e2e bench
harness. Those land in follow-up PRs once this layer is validated.

Validation done: yarn generate:wgsl regenerates cleanly, ShaderManager
end-to-end render produces 147 KB of WGSL with all mustache tags
substituted and every expected symbol present. Validation NOT done:
WebGPU compilation (no Playwright/SwiftShader in this container),
correctness vs CPU oracle, real-hardware perf — all needs the operator.
@AztecBot AztecBot added the claudebox Owned by claudebox. it can push to this PR. label May 19, 2026
AztecBot added 10 commits May 19, 2026 03:25
Forks bench-batch-affine.{ts,html} as bench-fused-wg-scan.{ts,html} —
exercises the new workgroup-scan fused round kernel against a synthetic
flat pool (one subtask, bucket_i = i, cursor_i = i, val_idx[i] = i)
with on-curve BN254 G1 affine pairs generated via noble. The output
R_i = P_i + Q_i is decoded from packed Mont form and compared to
noble's reference G1 add, so the bench is correctness-gated, not just
perf-gated.

- Sweep: BATCH_SIZE in {256, 512, 1024, 2048} at TOTAL_PAIRS = 65536
  by default. Per-size correctness check followed by perf rep loop;
  ns/pair printed per size for the M2 comparison.
- run-browserstack.mjs pageMap gets a new "bench-fused-wg-scan" entry
  so the BrowserStack runner can drive it via --page.

No production code touched. Pure additive dev-page bench. Vite picks
the new .html up automatically.
Surfaced by running the bench harness on BrowserStack M2 (Chrome 148).
Each fix is a Tint diagnostic the local render check could not catch.

1. packed_field.template.wgsl — the comment referenced `{{{ dec_unpack }}}`
   as literal mustache syntax. Mustache substituted the rendered
   `unpack256_to_limbs` body INTO the comment, so its second line broke
   out and Tint saw a function body at top level. Reworded the comment
   to not reference mustache tags.

2. batch_affine_fused_wg_scan — `count_buf: array<atomic<u32>>` was
   declared `read`, but atomics in storage must be `read_write` per
   WGSL spec. Flipped to read_write; bench TS bind-group layout
   updated to match.

3. batch_affine_fused_wg_scan — the early-return `if (batch_base >= n)
   return;` and `if (chunk_len == 0u) return;` are guarded by values
   derived from atomicLoad and so are considered non-uniform by Tint's
   uniformity analysis. The subsequent workgroupBarrier then sat in
   non-uniform control flow and was rejected. Restructured both phases
   to NOT early-return: threads with no work clamp chunk_len=0 and
   contribute identity (R = Mont 1) to the workgroup scan, skipping the
   actual load/store loops but staying live through barriers.

4. packed_field.template.wgsl — added `fn get_r()` (plain BigInt
   return) needed by fr_pow's body. Used to live in each call-site
   shader template; living in the packed_field partial makes it
   reusable across v2 shaders without duplication.

Correctness result on M2 after these fixes (TOTAL=4096, B=256):
  correctness=pass (4096/4096 pairs vs noble reference)
  median_ms=1.10, ns/pair=268 (small-N baseline; sweep to follow)
…d I/O only

After running the kernel three ways on BrowserStack M2:
  - workgroup-scan + PackedField wrappers (binding-wide unpack/pack
    inside every primitive)                                : 55 ns/pair
  - per-thread BATCH=16 + BigInt internals + packed I/O    : 56 ns/pair
  - workgroup-scan + BigInt internals + BigInt prefix_buf  : 76 ns/pair

All three correctness=pass. The PackedField wrapper layer adds ~30-50
ns/pair of pack/unpack overhead by repacking between primitive calls.
The right read of the design constraint is "unpack only at the storage
I/O boundary"; kernel-local vars and primitives stay in 20×13-bit
BigInt limb form for the whole kernel body.

This commit:
  - packed_field.template.wgsl reduced to the storage-I/O boundary:
    field_load_ro/_rw return BigInt directly (unpack at the load),
    field_store packs a BigInt into the packed 8×u32 slot. No
    PackedField struct, no mont_p/fr_*_p wrappers. get_r() stays.
  - batch_affine_fused_wg_scan.template.wgsl: kept the workgroup-scan
    shape (TPB=64 cooperate on TPB*BS pairs per workgroup with one
    fr_inv_by_a per workgroup) but every field-element variable is
    now BigInt. Primitives are the existing montgomery_product /
    fr_sub / fr_add / fr_inv_by_a, no wrappers. prefix_buf moves to
    array<BigInt> in storage.
  - bench TS bind-group layout updated for the 9-binding shape
    (added back prefix_buf). NUM_LIMBS_U32 inlined as 20.

WGSL/TS fixes surfaced during BS validation:
  - `active` is a reserved keyword in WGSL — renamed local to `in_pool`.
  - prefixBuf was sized using an undefined `NUM_LIMBS_U32` constant in
    the bench harness; inlined the value (20).

Open questions captured in the analysis gist:
https://gist.github.com/AztecBot/6b3c2702f313a16fbd645355f1789fc3
  - what fraction of the 22 ns → 55 ns gap is packed cost vs bucket
    indirection vs Tint codegen quirk
  - whether the bench's 22 ns/pair number still reproduces on current
    M2 (BS was severely queue-wedged during this session; could not
    re-measure)
Per-rep writeBuffer of runningXBuf/runningYBuf was inside the timed
dispatch() call, adding ~4 MB of CPU→GPU upload to every sample. At
M2 staging-upload rates that's ~1 ms per writeBuffer × 2 buffers
≈ 2 ms per rep. With observed per-dispatch wall ≈ 3.5 ms, the
writeBuffer was nearly half the measured time.

Now: writeBuffer once at setup (so the initial state is correct for
the correctness check on warmup dispatch), and only on the explicit
resetState=true path. Timed reps run with resetState=false — the
kernel keeps modifying running_x/y in-place, which doesn't affect
the throughput measurement (same load/sub/mul/store work per pair
regardless of input values).

Expected effect: my measured 55-76 ns/pair drops by the writeBuffer
share once re-measured on BS.
…cked I/O)

Recreate the lost ba_rev_packed_carry_bench WGSL kernel: per-thread
descending suffix-product over BS pairs, single fr_inv_by_a per thread,
ascending lean affine apply. Packed 8 x u32 storage at the I/O boundary
only; every kernel-local var lives as 13-bit BigInt limbs. No bucket
indirection, no scheduler inputs — flat addressing only — so this is the
idealised standalone setting we use to calibrate the M2 22-24 ns/pair
target the prior session reproduced (iter24 rev_packed_s16 = 24.22
ns/pair) and to quantify the gap from standalone to the integrated
fused round kernel.

Wiring:
- new WGSL template at wgsl/cuzk/ba_rev_packed_carry_bench.template.wgsl
- gen_ba_rev_packed_carry_bench_shader(tpb, bs) on ShaderManager
- dev/msm-webgpu/bench-ba-rev-packed-carry.{ts,html} host harness with
  noble on-curve oracle (R = P + Q), sweep over BS at fixed TPB
- run-browserstack.mjs pageMap entry "bench-ba-rev-packed-carry"
Replace the first-principles reconstruction with the canonical kernel
recovered by the remote agent in commit eab3a3e. Differences that
matter for the ~22 ns/pair M2 result:

- bucket-accumulate streaming chain (A_{i+1} := P_i resident load-carry)
  instead of independent P+Q pairs — forward prefix-product, ONE
  fr_inv_by_a per S-chunk, backward peel with dx recomputed free
- SoA-packed layout: one input buffer, 4 planes (A.x, A.y, P.x, P.y),
  each PG=2 vec4/elem; strided per-thread access e = t + i*T for full
  coalescing; separate 2-plane output buffer; params uniform = (N, T)
- host harness measures DISP=8 back-to-back dispatches per timed sample
  (amortises submit+drain — the dominant source of the 29 vs 22 ns gap
  in the prior single-dispatch reconstruction), PAIRS=131072, sweep S
  in {16,32,64}, sanity = readNonZero on R.x plane

shader_manager.gen_ba_rev_packed_carry_bench_shader now takes
(workgroup_size, s) to match the recovered template's mustache vars.
…t expanded in-comment

The recovered template's header comment literally contained the
{{{ dec_unpack }}} / {{{ dec_pack }}} tag text; mustache expanded them
inside the // comment block, spilling the unpack256_to_limbs function
body out of the comment (WGSL: statement found outside of function
body). Reworded to name the injected helpers instead.
Standalone WebGPU bench wiring the recovered ba_rev_packed_carry chain
kernel into a Pippenger MSM bucket-accumulate pair-tree, measuring the
first reduction level (N input points -> N/2 useful pair sums).

Three pieces:
- ba_marshal_chain_bench.template.wgsl: pure memory-shuffle kernel.
  Reads a CSR-sorted point index list + chunk plan, gathers packed
  point coords from an SoA-packed pool, writes them into the strided
  SoA layout (4 planes, PG=2 vec4/elem, T*S elements per plane,
  element index e = t + i*T) the chain kernel consumes. point_pool[0]
  is reserved as a universal decoy seed so the chain kernel's first dx
  per chunk is well-defined; csr_indices values are 1-based.
- shader_manager.gen_ba_marshal_chain_shader(workgroup_size, s).
- dev/msm-webgpu/bench-msm-chain.{ts,html}: host harness.
  - Synthetic CSR generator: N points uniformly assigned to B buckets,
    sorted; row-pointer offsets + per-bucket counts. With N=131072,
    B=8192 the average bucket size is 16, matching S=16 sweet spot.
  - Chunk plan: for each bucket with count >= S, emit floor(count/S)
    chunks of S consecutive points (tail points with count mod S are
    skipped in v1 and reported as `tail_points`).
  - Two-stage timed bench: marshal_ms via DISP=8 back-to-back dispatch
    loop, then chain_ms separately. Reports marshal_ns_per_pt,
    chain_ns_per_pt, combined_ns_per_pt, density = T*S/PAIRS.
  - Sweep S in {16, 32, 64} at fixed TPB=64.
  - Sanity: readNonZero on chain output (first packed elem of R.x).

Scope of v1 measures pair-tree level 0 only. Follow-ons not in this PR:
recursive level-1..log2(S) reduce passes (same chain kernel applied to
shrinking workload, ~25 ns/pt per level), tail-bucket handling for
count<S (workgroup-scan batch-affine fallback). Combined ns_per_pt
from this bench is the level-0 cost component; full MSM bucket
accumulate adds log2(S)/2 reduce levels on top.
…el waste)

Each thread reduces 2*S input points to S disjoint pair sums R_k =
P_{2k} + P_{2k+1} via the same forward-prefix-product / single
fr_inv_by_a / backward-peel batched-inverse pattern as
ba_rev_packed_carry, but with NO load-carry overlap. Every kernel
output is a distinct disjoint pair sum suitable as input to the next
pair-tree level. The chain kernel produced S overlapping sums of
which only S/2 were usable (R_1, R_3, ..., R_{S-1}); the disjoint
kernel produces S usable sums for the same per-thread inversion cost,
reclaiming the 50% kernel-efficiency loss.

Storage: SoA-packed 8x u32 per field (PG=2 vec4/elem). 2 input planes
(P.x, P.y) of 2*S*T elements each; 2 output planes (R.x, R.y) of S*T
each. dx values dx_k = P_{2k+1}.x - P_{2k}.x are mutually independent
(no shared inputs across k), so the Montgomery batched inverse trick
applies as-is — same Karatsuba+Yuval montmul, same BY-safegcd
fr_inv_by_a, ONE inversion amortised across the S pair sums per thread.

Files:
- wgsl/cuzk/ba_pair_disjoint_bench.template.wgsl
- gen_ba_pair_disjoint_bench_shader(workgroup_size, s) on ShaderManager
- dev/msm-webgpu/bench-ba-pair-disjoint.{ts,html} host harness with
  DISP=8 dispatch amortisation matched to bench-ba-rev-packed-carry
  for apples-to-apples ns/op comparison
- pageMap entry in run-browserstack.mjs

Expected: ns/op ~25 ns matching the chain kernel's per-add cost. Since
every output is now useful (vs S/2 useful in the chain), the ns/
useful-pair-sum metric should halve from ~50 to ~25 — the headline win
for full MSM bucket-accumulate via pair-tree reduction.
Three new kernels that compose into a full replacement for the cuZK
bucket-accumulate phase of Pippenger MSM:

1. ba_pair_disjoint_tree — tree variant of the disjoint pair-sum
   kernel. Writes outputs in the LAYOUT THE NEXT PAIR-TREE LEVEL
   EXPECTS, so multi-level reductions chain with no intervening
   marshal pass:
     out_pos(t, k) = (t >> 1) + (k + S * (t & 1)) * (T >> 1)
   Final-level flag (params.z) switches to a simple strided write.

2. ba_marshal_tree_l0 — CSR -> level-0 strided 2-plane input layout.
   Pure memory shuffle; same chunk-plan + CSR pattern as
   ba_marshal_chain_bench but 2-plane output (no decoy seed).

3. ba_tail_reduce — handles small buckets (count < 2*S). One thread
   per tail bucket, serial chain with one fr_inv_by_a per add. v1
   pragmatic design with no batched inversion across threads; closes
   correctness for buckets that can't use the disjoint kernel.

Host harness bench-msm-tree.{ts,html} drives the full pipeline:
  marshal-l0 -> tree-disjoint level 0 -> level 1 -> ... -> final
                                                            tail
Reports per-stage and combined ns/in-pt over total points. Modes:
  ?mode=uniform : every bucket has exactly 2*S = 32 points (clean
                  multi-level test, no tail).
  ?mode=skewed  : Poisson via uniform random scalar assignment, both
                  main path and tail exercised.

shader_manager.ts:
  - gen_ba_pair_disjoint_tree_bench_shader(workgroup_size, s)
  - gen_ba_marshal_tree_l0_bench_shader(workgroup_size, s)
  - gen_ba_tail_reduce_bench_shader(workgroup_size, s)  (TAIL_CAP = 2*S - 1)

run-browserstack.mjs: pageMap entry "bench-msm-tree".

This is the complete kernel set for plugging the disjoint pair-sum
batch-affine design into the production MSM bucket-accumulate phase.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

claudebox Owned by claudebox. it can push to this PR.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant