feat(bb/msm): workgroup-scan fused round kernel + PackedField primitive layer (v2 foundation)#23385
Draft
AztecBot wants to merge 11 commits into
Draft
feat(bb/msm): workgroup-scan fused round kernel + PackedField primitive layer (v2 foundation)#23385AztecBot wants to merge 11 commits into
AztecBot wants to merge 11 commits into
Conversation
…ve layer
Foundation for msm_webgpu_v2 rewrite. Two new files; no existing code modified
beyond shader_manager.ts and the auto-regenerated wgsl/_generated/shaders.ts.
- wgsl/struct/packed_field.template.wgsl — new partial defining
struct PackedField { lo: vec4<u32>, hi: vec4<u32> } and the wrappers
mont_p, fr_add_p, fr_sub_p, fr_neg_p, fr_inv_p, plus field_load_ro/_rw,
field_store, is_zero_packed, eq_packed, get_p_packed, get_r_packed,
get_zero_packed. Each primitive body is a 3-line unpack-call-pack wrapper
around the existing BigInt-limb implementations. The 20x13-bit BigInt
representation only ever exists transiently inside these wrappers.
- wgsl/cuzk/batch_affine_fused_wg_scan.template.wgsl — new fused round
kernel. TPB threads cooperating on BATCH_SIZE = TPB*BS pairs per
workgroup with one fr_inv_by_a per workgroup. Direct port of the
bench_batch_affine design (validated at 22 ns/pair on M2) with
bucket-indirect loads/stores via pair_target_meta. Every field-element
variable is PackedField. The scheduler-enforced distinct-buckets
invariant means the phase-D scatter has zero intra-workgroup RAW
hazards.
- cuzk/shader_manager.ts — adds gen_batch_affine_fused_wg_scan_shader.
Validates TPB is a power of two for Hillis-Steele.
Not in this PR: msm_webgpu_v2 directory fork, packed ports of BPR /
horner / finalize / convert, host-side dispatch rewrite, e2e bench
harness. Those land in follow-up PRs once this layer is validated.
Validation done: yarn generate:wgsl regenerates cleanly, ShaderManager
end-to-end render produces 147 KB of WGSL with all mustache tags
substituted and every expected symbol present. Validation NOT done:
WebGPU compilation (no Playwright/SwiftShader in this container),
correctness vs CPU oracle, real-hardware perf — all needs the operator.
Forks bench-batch-affine.{ts,html} as bench-fused-wg-scan.{ts,html} —
exercises the new workgroup-scan fused round kernel against a synthetic
flat pool (one subtask, bucket_i = i, cursor_i = i, val_idx[i] = i)
with on-curve BN254 G1 affine pairs generated via noble. The output
R_i = P_i + Q_i is decoded from packed Mont form and compared to
noble's reference G1 add, so the bench is correctness-gated, not just
perf-gated.
- Sweep: BATCH_SIZE in {256, 512, 1024, 2048} at TOTAL_PAIRS = 65536
by default. Per-size correctness check followed by perf rep loop;
ns/pair printed per size for the M2 comparison.
- run-browserstack.mjs pageMap gets a new "bench-fused-wg-scan" entry
so the BrowserStack runner can drive it via --page.
No production code touched. Pure additive dev-page bench. Vite picks
the new .html up automatically.
Surfaced by running the bench harness on BrowserStack M2 (Chrome 148).
Each fix is a Tint diagnostic the local render check could not catch.
1. packed_field.template.wgsl — the comment referenced `{{{ dec_unpack }}}`
as literal mustache syntax. Mustache substituted the rendered
`unpack256_to_limbs` body INTO the comment, so its second line broke
out and Tint saw a function body at top level. Reworded the comment
to not reference mustache tags.
2. batch_affine_fused_wg_scan — `count_buf: array<atomic<u32>>` was
declared `read`, but atomics in storage must be `read_write` per
WGSL spec. Flipped to read_write; bench TS bind-group layout
updated to match.
3. batch_affine_fused_wg_scan — the early-return `if (batch_base >= n)
return;` and `if (chunk_len == 0u) return;` are guarded by values
derived from atomicLoad and so are considered non-uniform by Tint's
uniformity analysis. The subsequent workgroupBarrier then sat in
non-uniform control flow and was rejected. Restructured both phases
to NOT early-return: threads with no work clamp chunk_len=0 and
contribute identity (R = Mont 1) to the workgroup scan, skipping the
actual load/store loops but staying live through barriers.
4. packed_field.template.wgsl — added `fn get_r()` (plain BigInt
return) needed by fr_pow's body. Used to live in each call-site
shader template; living in the packed_field partial makes it
reusable across v2 shaders without duplication.
Correctness result on M2 after these fixes (TOTAL=4096, B=256):
correctness=pass (4096/4096 pairs vs noble reference)
median_ms=1.10, ns/pair=268 (small-N baseline; sweep to follow)
…d I/O only
After running the kernel three ways on BrowserStack M2:
- workgroup-scan + PackedField wrappers (binding-wide unpack/pack
inside every primitive) : 55 ns/pair
- per-thread BATCH=16 + BigInt internals + packed I/O : 56 ns/pair
- workgroup-scan + BigInt internals + BigInt prefix_buf : 76 ns/pair
All three correctness=pass. The PackedField wrapper layer adds ~30-50
ns/pair of pack/unpack overhead by repacking between primitive calls.
The right read of the design constraint is "unpack only at the storage
I/O boundary"; kernel-local vars and primitives stay in 20×13-bit
BigInt limb form for the whole kernel body.
This commit:
- packed_field.template.wgsl reduced to the storage-I/O boundary:
field_load_ro/_rw return BigInt directly (unpack at the load),
field_store packs a BigInt into the packed 8×u32 slot. No
PackedField struct, no mont_p/fr_*_p wrappers. get_r() stays.
- batch_affine_fused_wg_scan.template.wgsl: kept the workgroup-scan
shape (TPB=64 cooperate on TPB*BS pairs per workgroup with one
fr_inv_by_a per workgroup) but every field-element variable is
now BigInt. Primitives are the existing montgomery_product /
fr_sub / fr_add / fr_inv_by_a, no wrappers. prefix_buf moves to
array<BigInt> in storage.
- bench TS bind-group layout updated for the 9-binding shape
(added back prefix_buf). NUM_LIMBS_U32 inlined as 20.
WGSL/TS fixes surfaced during BS validation:
- `active` is a reserved keyword in WGSL — renamed local to `in_pool`.
- prefixBuf was sized using an undefined `NUM_LIMBS_U32` constant in
the bench harness; inlined the value (20).
Open questions captured in the analysis gist:
https://gist.github.com/AztecBot/6b3c2702f313a16fbd645355f1789fc3
- what fraction of the 22 ns → 55 ns gap is packed cost vs bucket
indirection vs Tint codegen quirk
- whether the bench's 22 ns/pair number still reproduces on current
M2 (BS was severely queue-wedged during this session; could not
re-measure)
Per-rep writeBuffer of runningXBuf/runningYBuf was inside the timed dispatch() call, adding ~4 MB of CPU→GPU upload to every sample. At M2 staging-upload rates that's ~1 ms per writeBuffer × 2 buffers ≈ 2 ms per rep. With observed per-dispatch wall ≈ 3.5 ms, the writeBuffer was nearly half the measured time. Now: writeBuffer once at setup (so the initial state is correct for the correctness check on warmup dispatch), and only on the explicit resetState=true path. Timed reps run with resetState=false — the kernel keeps modifying running_x/y in-place, which doesn't affect the throughput measurement (same load/sub/mul/store work per pair regardless of input values). Expected effect: my measured 55-76 ns/pair drops by the writeBuffer share once re-measured on BS.
…cked I/O)
Recreate the lost ba_rev_packed_carry_bench WGSL kernel: per-thread
descending suffix-product over BS pairs, single fr_inv_by_a per thread,
ascending lean affine apply. Packed 8 x u32 storage at the I/O boundary
only; every kernel-local var lives as 13-bit BigInt limbs. No bucket
indirection, no scheduler inputs — flat addressing only — so this is the
idealised standalone setting we use to calibrate the M2 22-24 ns/pair
target the prior session reproduced (iter24 rev_packed_s16 = 24.22
ns/pair) and to quantify the gap from standalone to the integrated
fused round kernel.
Wiring:
- new WGSL template at wgsl/cuzk/ba_rev_packed_carry_bench.template.wgsl
- gen_ba_rev_packed_carry_bench_shader(tpb, bs) on ShaderManager
- dev/msm-webgpu/bench-ba-rev-packed-carry.{ts,html} host harness with
noble on-curve oracle (R = P + Q), sweep over BS at fixed TPB
- run-browserstack.mjs pageMap entry "bench-ba-rev-packed-carry"
Replace the first-principles reconstruction with the canonical kernel recovered by the remote agent in commit eab3a3e. Differences that matter for the ~22 ns/pair M2 result: - bucket-accumulate streaming chain (A_{i+1} := P_i resident load-carry) instead of independent P+Q pairs — forward prefix-product, ONE fr_inv_by_a per S-chunk, backward peel with dx recomputed free - SoA-packed layout: one input buffer, 4 planes (A.x, A.y, P.x, P.y), each PG=2 vec4/elem; strided per-thread access e = t + i*T for full coalescing; separate 2-plane output buffer; params uniform = (N, T) - host harness measures DISP=8 back-to-back dispatches per timed sample (amortises submit+drain — the dominant source of the 29 vs 22 ns gap in the prior single-dispatch reconstruction), PAIRS=131072, sweep S in {16,32,64}, sanity = readNonZero on R.x plane shader_manager.gen_ba_rev_packed_carry_bench_shader now takes (workgroup_size, s) to match the recovered template's mustache vars.
…t expanded in-comment
The recovered template's header comment literally contained the
{{{ dec_unpack }}} / {{{ dec_pack }}} tag text; mustache expanded them
inside the // comment block, spilling the unpack256_to_limbs function
body out of the comment (WGSL: statement found outside of function
body). Reworded to name the injected helpers instead.
Standalone WebGPU bench wiring the recovered ba_rev_packed_carry chain
kernel into a Pippenger MSM bucket-accumulate pair-tree, measuring the
first reduction level (N input points -> N/2 useful pair sums).
Three pieces:
- ba_marshal_chain_bench.template.wgsl: pure memory-shuffle kernel.
Reads a CSR-sorted point index list + chunk plan, gathers packed
point coords from an SoA-packed pool, writes them into the strided
SoA layout (4 planes, PG=2 vec4/elem, T*S elements per plane,
element index e = t + i*T) the chain kernel consumes. point_pool[0]
is reserved as a universal decoy seed so the chain kernel's first dx
per chunk is well-defined; csr_indices values are 1-based.
- shader_manager.gen_ba_marshal_chain_shader(workgroup_size, s).
- dev/msm-webgpu/bench-msm-chain.{ts,html}: host harness.
- Synthetic CSR generator: N points uniformly assigned to B buckets,
sorted; row-pointer offsets + per-bucket counts. With N=131072,
B=8192 the average bucket size is 16, matching S=16 sweet spot.
- Chunk plan: for each bucket with count >= S, emit floor(count/S)
chunks of S consecutive points (tail points with count mod S are
skipped in v1 and reported as `tail_points`).
- Two-stage timed bench: marshal_ms via DISP=8 back-to-back dispatch
loop, then chain_ms separately. Reports marshal_ns_per_pt,
chain_ns_per_pt, combined_ns_per_pt, density = T*S/PAIRS.
- Sweep S in {16, 32, 64} at fixed TPB=64.
- Sanity: readNonZero on chain output (first packed elem of R.x).
Scope of v1 measures pair-tree level 0 only. Follow-ons not in this PR:
recursive level-1..log2(S) reduce passes (same chain kernel applied to
shrinking workload, ~25 ns/pt per level), tail-bucket handling for
count<S (workgroup-scan batch-affine fallback). Combined ns_per_pt
from this bench is the level-0 cost component; full MSM bucket
accumulate adds log2(S)/2 reduce levels on top.
…el waste)
Each thread reduces 2*S input points to S disjoint pair sums R_k =
P_{2k} + P_{2k+1} via the same forward-prefix-product / single
fr_inv_by_a / backward-peel batched-inverse pattern as
ba_rev_packed_carry, but with NO load-carry overlap. Every kernel
output is a distinct disjoint pair sum suitable as input to the next
pair-tree level. The chain kernel produced S overlapping sums of
which only S/2 were usable (R_1, R_3, ..., R_{S-1}); the disjoint
kernel produces S usable sums for the same per-thread inversion cost,
reclaiming the 50% kernel-efficiency loss.
Storage: SoA-packed 8x u32 per field (PG=2 vec4/elem). 2 input planes
(P.x, P.y) of 2*S*T elements each; 2 output planes (R.x, R.y) of S*T
each. dx values dx_k = P_{2k+1}.x - P_{2k}.x are mutually independent
(no shared inputs across k), so the Montgomery batched inverse trick
applies as-is — same Karatsuba+Yuval montmul, same BY-safegcd
fr_inv_by_a, ONE inversion amortised across the S pair sums per thread.
Files:
- wgsl/cuzk/ba_pair_disjoint_bench.template.wgsl
- gen_ba_pair_disjoint_bench_shader(workgroup_size, s) on ShaderManager
- dev/msm-webgpu/bench-ba-pair-disjoint.{ts,html} host harness with
DISP=8 dispatch amortisation matched to bench-ba-rev-packed-carry
for apples-to-apples ns/op comparison
- pageMap entry in run-browserstack.mjs
Expected: ns/op ~25 ns matching the chain kernel's per-add cost. Since
every output is now useful (vs S/2 useful in the chain), the ns/
useful-pair-sum metric should halve from ~50 to ~25 — the headline win
for full MSM bucket-accumulate via pair-tree reduction.
Three new kernels that compose into a full replacement for the cuZK
bucket-accumulate phase of Pippenger MSM:
1. ba_pair_disjoint_tree — tree variant of the disjoint pair-sum
kernel. Writes outputs in the LAYOUT THE NEXT PAIR-TREE LEVEL
EXPECTS, so multi-level reductions chain with no intervening
marshal pass:
out_pos(t, k) = (t >> 1) + (k + S * (t & 1)) * (T >> 1)
Final-level flag (params.z) switches to a simple strided write.
2. ba_marshal_tree_l0 — CSR -> level-0 strided 2-plane input layout.
Pure memory shuffle; same chunk-plan + CSR pattern as
ba_marshal_chain_bench but 2-plane output (no decoy seed).
3. ba_tail_reduce — handles small buckets (count < 2*S). One thread
per tail bucket, serial chain with one fr_inv_by_a per add. v1
pragmatic design with no batched inversion across threads; closes
correctness for buckets that can't use the disjoint kernel.
Host harness bench-msm-tree.{ts,html} drives the full pipeline:
marshal-l0 -> tree-disjoint level 0 -> level 1 -> ... -> final
tail
Reports per-stage and combined ns/in-pt over total points. Modes:
?mode=uniform : every bucket has exactly 2*S = 32 points (clean
multi-level test, no tail).
?mode=skewed : Poisson via uniform random scalar assignment, both
main path and tail exercised.
shader_manager.ts:
- gen_ba_pair_disjoint_tree_bench_shader(workgroup_size, s)
- gen_ba_marshal_tree_l0_bench_shader(workgroup_size, s)
- gen_ba_tail_reduce_bench_shader(workgroup_size, s) (TAIL_CAP = 2*S - 1)
run-browserstack.mjs: pageMap entry "bench-msm-tree".
This is the complete kernel set for plugging the disjoint pair-sum
batch-affine design into the production MSM bucket-accumulate phase.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Why
Foundation for the
msm_webgpu_v2rewrite described in the plan gist. Two intertwined problems with the currentfused_revcarrypath on the parent branch:batch_affine_fused_revcarry.template.wgslslices the per-subtask pair pool intoSCHUNK=16-pair per-thread chunks with onefr_inv_by_aper thread. The validated bench (bench_batch_affine.template.wgsl, 22 ns/pair on M2) uses a workgroup-level Hillis-Steele scan with onefr_inv_by_aper workgroup.{{#packed}}/{{^packed}}mustache pairs; the last 3 commits on this branch are layout-mismatch bugs from that approach. The v2 plan switches to packed-only with no fallback — primitives are the only place pack/unpack appears.What's in this PR
Foundation only — type system + one new kernel + a standalone correctness-gated microbench. Does NOT yet replace the round loop, port BPR / horner / finalize, or wire a v2 orchestrator.
wgsl/struct/packed_field.template.wgsl— definesstruct PackedField { lo: vec4<u32>, hi: vec4<u32> }and the wrappersmont_p,fr_add_p,fr_sub_p,fr_neg_p,fr_inv_p, plusfield_load_ro/_rw,field_store,is_zero_packed,eq_packed,get_p_packed,get_r_packed,get_zero_packed,get_r. Each primitive body is a 3-line unpack-call-pack around the existing BigInt-limb implementation.wgsl/cuzk/batch_affine_fused_wg_scan.template.wgsl— new fused round kernel. TPB threads cooperating onBATCH_SIZE = TPB*BSpairs per workgroup with onefr_inv_by_aper workgroup. Direct port of thebench_batch_affinedesign with bucket-indirect loads/stores viapair_target_meta. Every field-element variable isPackedField.cuzk/shader_manager.ts— addsgen_batch_affine_fused_wg_scan_shader(tpb, bs).dev/msm-webgpu/bench-fused-wg-scan.{ts,html}— standalone bench harness with full noble-curves correctness oracle (on-curve BN254 G1 pairs; GPU output decoded from packed Mont form, compared bit-exact toP.add(Q).toAffine()).dev/msm-webgpu/scripts/run-browserstack.mjs— addsbench-fused-wg-scanto the pageMap so the runner can drive it via--page.Validation on BrowserStack M2 (Chrome 148)
Correctness: every sweep size, every dispatched run, returned
correctness="pass"— all 4096 / 65536 R_i = P_i + Q_i pairs match noble's reference bit-exact.Perf curve at TOTAL=65536 pairs:
bench_batch_affine's reference number on the same hardware is ~22 ns/pair at B=1024 (BigInt-direct primitives, no bucket indirection). The new kernel sits at 2× the bench's theoretical ceiling at the best batch size — the residual gap is the PackedField wrapper overhead (2 unpacks + 1 pack per primitive call) plus the bucket-indirectpair_target_meta+val_idxlookups per pair.The new sweet spot is B=2048 (BS=32) rather than the bench's B=1024 — bigger BS amortises both the workgroup scan and the inversion better when the per-pair work is heavier (PackedField wrappers + bucket indirection). On the prior
fused_revcarryper-thread design the estimate was ~60-80 ns/pair on the same hardware, so this is a 30-40 % round-kernel improvement on top of design clarity.Sweep data points at TOTAL=4096 (GPU-underutilised regime, for completeness):
At small N the workgroup count drops below the GPU's SM count and the scan/inversion amortisation gets washed out by launch overhead — expected, only relevant as a sanity baseline.
What's NOT in this PR
Per the plan gist:
msm_webgpu_v2/directory scaffoldba_init,ba_schedule,ba_finalize_*,batch_inverse_parallel,convert_points_only,bpr_bn254,horner_reduce_bn254batch_affine.ts/msm.tsrewriteThree Tint-only fixes during validation
Surfaced by running the bench on BS M2 (Chrome 148) — none caught by the local render-time sanity check:
packed_field.template.wgsl— the comment referenced{{{ dec_unpack }}}as literal mustache syntax. Mustache substituted the renderedunpack256_to_limbsbody INTO the comment, so its second line broke out and Tint saw a function body at top level. Reworded the comment to not reference mustache tags.batch_affine_fused_wg_scan—count_buf: array<atomic<u32>>was declaredread; atomics in storage must beread_writeper WGSL spec. Flipped toread_write; bench TS bind-group layout updated to match.batch_affine_fused_wg_scan— the early-returnif (batch_base >= n) return;was guarded by an atomicLoad result, which Tint's uniformity analysis treats as non-uniform. The subsequentworkgroupBarrierwas rejected. Restructured to NOT early-return: threads with no work clampchunk_len=0and contribute identity (R= Mont 1) to the workgroup scan, skipping the actual load/store loops but staying live through barriers.Also added
fn get_r()to the packed_field partial sofr_pow_funcs' internal reference resolves.Next steps
msm_webgpu_v2/).References