diff --git a/barretenberg/ts/dev/msm-webgpu/bench-ba-pair-disjoint.html b/barretenberg/ts/dev/msm-webgpu/bench-ba-pair-disjoint.html
new file mode 100644
index 000000000000..95dc4bee2cdb
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-ba-pair-disjoint.html
@@ -0,0 +1,37 @@
+
+
+
+
+ Disjoint pair-sum batch-affine bench (WebGPU)
+
+
+
+ Disjoint pair-sum batch-affine bench (WebGPU)
+ Query params: ?reps=R&pairs=N&wgi=W&disp=D&s=A,B,C
+
+
+
+
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-ba-pair-disjoint.ts b/barretenberg/ts/dev/msm-webgpu/bench-ba-pair-disjoint.ts
new file mode 100644
index 000000000000..500fff0eba60
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-ba-pair-disjoint.ts
@@ -0,0 +1,435 @@
+///
+// Standalone WebGPU bench for the disjoint pair-sum kernel
+// (ba_pair_disjoint_bench): each thread reduces 2*S input points to
+// S disjoint pair sums R_k = P_{2k} + P_{2k+1}, using one batched
+// fr_inv_by_a per chunk of S. Same DISP=8 dispatch amortisation as
+// bench-ba-rev-packed-carry to keep the measurement methodology
+// apples-to-apples.
+//
+// Input: random Montgomery field elems packed into SoA layout with
+// 2 planes (P.x, P.y), 2*PAIRS elements per plane. Pairs are arranged
+// at strided positions e = t + i*T for i in 0..2S so the kernel's
+// coalesced reads work. Adjacent pair members are guaranteed distinct
+// x to avoid the lean-add div-by-zero.
+//
+// Output: 2 planes (R.x, R.y), PAIRS elements per plane. Reports
+// ns/useful-pair-sum (every kernel output is a usable disjoint pair
+// sum, vs the chain kernel where only S/2 of S outputs are usable).
+
+import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
+import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
+import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
+import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js';
+import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js';
+import { makeResultsClient } from './results_post.js';
+
+// Number of 4 x u32 groups that make up one packed 8 x u32 field element
+// (bigintToPackedU32x8 emits 8 words; the packer writes 4 words per group).
+const PG = 2;
+// Defaults used when the matching query parameter is absent.
+const DEFAULT_PAIRS = 1 << 17; // 131072 output pair sums per dispatch
+const DEFAULT_WGI = 64;
+const DEFAULT_DISP = 8;
+const DEFAULT_S_SWEEP: readonly number[] = [16, 32, 64];
+
+// Live configuration. parseParams() overwrites these from the URL query
+// string; runOne() reads them directly rather than taking them as arguments.
+// PAIRS   - pair sums per dispatch (the input holds 2 * PAIRS elements per plane)
+// WGI     - first argument to the shader generator and divisor for the workgroup count
+// DISP    - compute passes recorded into each timed submission
+// S_SWEEP - chunk sizes S to benchmark, in order
+let PAIRS = DEFAULT_PAIRS;
+let WGI = DEFAULT_WGI;
+let DISP = DEFAULT_DISP;
+let S_SWEEP: readonly number[] = DEFAULT_S_SWEEP;
+
+/**
+ * Creates a deterministic 32-bit linear congruential generator
+ * (multiplier 1664525, increment 1013904223, modulus 2^32).
+ *
+ * @param seed Starting state; coerced to uint32, and a zero result is replaced by 1.
+ * @returns A function that advances the generator and returns the new uint32 state.
+ */
+function makeRng(seed: number): () => number {
+  const LCG_MUL = 1664525;
+  const LCG_INC = 1013904223;
+  const seed32 = seed >>> 0;
+  let current = seed32 === 0 ? 1 : seed32;
+  return function next(): number {
+    // Math.imul keeps the product in 32-bit integer range; `>>> 0` makes it unsigned.
+    current = (Math.imul(current, LCG_MUL) + LCG_INC) >>> 0;
+    return current;
+  };
+}
+
+/**
+ * Draws a bigint in the open interval (0, p) by rejection sampling:
+ * assemble ceil(bitlen(p) / 8) bytes from `rng`, mask down to bitlen(p)
+ * bits, and retry until the candidate is non-zero and below p.
+ *
+ * Each byte comes from the TOP 8 bits of the 32-bit rng output. With a
+ * power-of-two-modulus LCG (as produced by makeRng) the low byte cycles
+ * with period 256, so sampling `& 0xff` made 32-byte candidates repeat
+ * after only 8 draws. The high byte has the generator's full period.
+ *
+ * @param p   Exclusive upper bound; must be greater than 1.
+ * @param rng Source of uint32 values.
+ * @returns A value v with 0 < v < p.
+ * @throws RangeError when p <= 1 (no value exists in (0, p), so the loop could never end).
+ */
+function randomBelow(p: bigint, rng: () => number): bigint {
+  if (p <= 1n) throw new RangeError(`randomBelow: p must be > 1, got ${p}`);
+  const bitlen = p.toString(2).length;
+  const byteLen = Math.ceil(bitlen / 8);
+  // Loop-invariant: trims the assembled bytes to exactly bitlen bits.
+  const mask = (1n << BigInt(bitlen)) - 1n;
+  for (;;) {
+    let v = 0n;
+    for (let i = 0; i < byteLen; i++) v = (v << 8n) | BigInt((rng() >>> 24) & 0xff);
+    v &= mask;
+    if (v > 0n && v < p) return v;
+  }
+}
+
+/**
+ * Splits the low 256 bits of `v` into eight 32-bit limbs, least
+ * significant limb first.
+ *
+ * @param v Value to pack; bits at position 256 and above are dropped.
+ * @returns A new Uint32Array of length 8.
+ */
+function bigintToPackedU32x8(v: bigint): Uint32Array {
+  const limbs = new Uint32Array(8);
+  for (let limb = 0, shift = 0n; limb < 8; limb++, shift += 32n) {
+    limbs[limb] = Number(BigInt.asUintN(32, v >> shift));
+  }
+  return limbs;
+}
+
+/**
+ * Returns the middle element of `xs` in ascending order. For an
+ * even-length input this is the upper of the two middle elements (no
+ * averaging). The input array is not modified.
+ *
+ * @returns NaN when `xs` is empty.
+ */
+function median(xs: number[]): number {
+  const count = xs.length;
+  if (count === 0) return NaN;
+  const ascending = [...xs].sort((lhs, rhs) => lhs - rhs);
+  return ascending[count >> 1];
+}
+
+// Builds the SoA-packed input: 2 planes (P.x, P.y), each holding
+// N_in = 2 * pairs elements split into PG groups of 4 u32. Group v of
+// plane p for element e starts at u32 index ((p * PG + v) * N_in + e) * 4.
+// Elements 2k and 2k+1 form pair k and always receive different x values.
+// Every coordinate is a randomBelow(p) draw multiplied by R mod p.
+function buildPackedPairsSoA(pairs: number, R: bigint, p: bigint, rng: () => number): Uint32Array {
+  const elemsPerPlane = 2 * pairs;
+  const packed = new Uint32Array(2 * PG * elemsPerPlane * 4);
+
+  // Scatters the 8 limbs of `val` into the PG group slots of (planeIdx, e).
+  const storeElem = (planeIdx: number, e: number, val: bigint): void => {
+    const limbs = bigintToPackedU32x8(val);
+    for (let g = 0; g < PG; g++) {
+      const dst = ((planeIdx * PG + g) * elemsPerPlane + e) * 4;
+      packed.set(limbs.subarray(4 * g, 4 * g + 4), dst);
+    }
+  };
+  // One fresh coordinate: random draw scaled by R, reduced mod p.
+  const nextCoord = (): bigint => (randomBelow(p, rng) * R) % p;
+
+  for (let k = 0; k < pairs; k++) {
+    // Redraw both x coordinates together until they differ.
+    let leftX = nextCoord();
+    let rightX = nextCoord();
+    while (leftX === rightX) {
+      leftX = nextCoord();
+      rightX = nextCoord();
+    }
+    const leftY = nextCoord();
+    const rightY = nextCoord();
+
+    const leftIdx = 2 * k;
+    const rightIdx = leftIdx + 1;
+    storeElem(0, leftIdx, leftX);
+    storeElem(1, leftIdx, leftY);
+    storeElem(0, rightIdx, rightX);
+    storeElem(1, rightIdx, rightY);
+  }
+  return packed;
+}
+
+// One row of the sweep: configuration and timing for a single chunk size S,
+// as assembled at the end of runOne().
+interface PerSizeResult {
+ // Chunk size S this row was measured with.
+ s: number;
+ // WGI value in effect for the run.
+ wgi: number;
+ // PAIRS / s.
+ T: number;
+ // Workgroups per dispatch: ceil(T / WGI).
+ num_wgs: number;
+ // PAIRS value in effect for the run.
+ pairs: number;
+ // Compute passes per timed submission (DISP).
+ disp: number;
+ // pairs * disp; the divisor used for ns_per_op.
+ total_ops: number;
+ // median() of samples_ms.
+ median_ms: number;
+ min_ms: number;
+ max_ms: number;
+ // median_ms * 1e6 / total_ops.
+ ns_per_op: number;
+ // Wall-clock milliseconds of each timed submission, in run order.
+ samples_ms: number[];
+ // True when the first 8 u32 of the output buffer were not all zero.
+ sanity_ok: boolean;
+}
+
+// Whole-run status object; it is published on window.__bench and is the
+// source of the payload sent by postFinal().
+interface BenchState {
+ // Lifecycle: 'boot' until params parse, 'running' during the sweep, then 'done' or 'error'.
+ state: 'boot' | 'running' | 'done' | 'error';
+ // Parsed query parameters; null until parseParams() has succeeded.
+ params: { reps: number; pairs: number; wgi: number; disp: number; s_sweep: readonly number[] } | null;
+ // One entry per completed S value, in sweep order.
+ results: PerSizeResult[];
+ // Message of the failure that ended the run, or null.
+ error: string | null;
+ // Every log() message, formatted as "[level] msg".
+ log: string[];
+}
+
+// Single mutable run-state instance, updated by log(), main() and the entry chain.
+const benchState: BenchState = {
+ state: 'boot',
+ params: null,
+ results: [],
+ error: null,
+ log: [],
+};
+// Publish the state on window so it can be inspected from outside the module.
+(window as unknown as { __bench: BenchState }).__bench = benchState;
+
+// Results reporter for this page; its run id is also published on window.
+const resultsClient = makeResultsClient({ page: 'bench-ba-pair-disjoint' });
+(window as unknown as { __runId: string }).__runId = resultsClient.runId;
+
+/**
+ * Sends the final snapshot of benchState (state, params, results, error,
+ * log) plus the browser's user agent and hardwareConcurrency through the
+ * results client.
+ *
+ * The return type is spelled `Promise<void>`: a bare `Promise` is
+ * rejected by the compiler because the generic needs its type argument.
+ *
+ * @returns A promise that settles when resultsClient.postResults settles.
+ */
+async function postFinal(): Promise<void> {
+  await resultsClient.postResults({
+    state: benchState.state,
+    params: benchState.params,
+    results: benchState.results,
+    error: benchState.error,
+    log: benchState.log,
+    userAgent: navigator.userAgent,
+    hardwareConcurrency: navigator.hardwareConcurrency,
+  });
+}
+
+// On-page log container; null when the host page has no element with id "log".
+const $log = document.getElementById('log');
+
+/**
+ * Records one message in benchState.log and the console, then appends
+ * it to the on-page log when that element exists.
+ *
+ * The DOM write is last and null-guarded. log() is called from the
+ * error handlers, so it must not throw before the message is recorded;
+ * an unguarded append on a missing element would do exactly that.
+ *
+ * @param level Severity; also used as the row's CSS class ('info' maps to no class).
+ * @param msg   Text to record.
+ */
+function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) {
+  benchState.log.push(`[${level}] ${msg}`);
+  console.log(`[bench-ba-pair-disjoint] ${msg}`);
+  if ($log === null) return;
+  const row = document.createElement('div');
+  row.className = level === 'info' ? '' : level;
+  row.textContent = msg;
+  $log.appendChild(row);
+}
+
+// Compiles `code` into a compute pipeline bound to this bench's fixed
+// four-entry bind group layout.
+//   device   - WebGPU device used for all object creation
+//   code     - shader source text
+//   cacheKey - label used only to tag diagnostics and the thrown error
+// Returns the pipeline together with the explicit layout so the caller
+// can build a matching bind group.
+// Throws Error if the compilation info contains any message of type 'error'.
+async function compile(
+ device: GPUDevice,
+ code: string,
+ cacheKey: string,
+): Promise<{ pipeline: GPUComputePipeline; layout: GPUBindGroupLayout }> {
+ const module = device.createShaderModule({ code });
+ const info = await module.getCompilationInfo();
+ let hasError = false;
+ const errLines: string[] = [];
+ // Surface every diagnostic; only 'error' entries abort the compile.
+ for (const msg of info.messages) {
+ const line = `[shader ${cacheKey}] ${msg.type}: ${msg.message} (line ${msg.lineNum}, col ${msg.linePos})`;
+ if (msg.type === 'error') {
+ console.error(line);
+ log('err', line);
+ errLines.push(line);
+ hasError = true;
+ } else {
+ console.warn(line);
+ log('warn', line);
+ }
+ }
+ if (hasError) throw new Error(`WGSL compile failed for ${cacheKey}: ${errLines.join(' | ')}`);
+ // Bindings 0 and 1: read-only storage; 2: writable storage; 3: uniform.
+ const layout = device.createBindGroupLayout({
+ entries: [
+ { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+ { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+ { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+ { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } },
+ ],
+ });
+ const pipeline = await device.createComputePipelineAsync({
+ layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }),
+ compute: { module, entryPoint: 'main' },
+ });
+ return { pipeline, layout };
+}
+
+/**
+ * Copies the first `u32Count` 32-bit words of `buf` into a temporary
+ * mappable buffer and reports whether any of them is non-zero.
+ *
+ * The staging buffer is destroyed in a `finally`, so it is released
+ * even when the copy or the map rejects. The mapped range is scanned
+ * in place before unmapping, so no intermediate copy is made.
+ *
+ * @param device   Device that owns `buf`.
+ * @param buf      Source buffer; must have been created with COPY_SRC usage.
+ * @param u32Count Number of leading u32 words to inspect.
+ * @returns true if at least one inspected word is non-zero.
+ */
+async function readNonZero(device: GPUDevice, buf: GPUBuffer, u32Count: number): Promise<boolean> {
+  const bytes = u32Count * 4;
+  const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST });
+  try {
+    const enc = device.createCommandEncoder();
+    enc.copyBufferToBuffer(buf, 0, staging, 0, bytes);
+    device.queue.submit([enc.finish()]);
+    await device.queue.onSubmittedWorkDone();
+    await staging.mapAsync(GPUMapMode.READ);
+    // The view is only valid while mapped, so finish the scan before unmap().
+    const found = new Uint32Array(staging.getMappedRange()).some(word => word !== 0);
+    staging.unmap();
+    return found;
+  } finally {
+    staging.destroy();
+  }
+}
+
+/**
+ * Times `reps` submissions, each containing `passes` identical compute
+ * passes, after one untimed warm-up submission of the same shape.
+ *
+ * Command encoding happens before the clock starts; each sample covers
+ * queue.submit() through onSubmittedWorkDone() only.
+ *
+ * @param device   Device whose queue runs the work.
+ * @param pipeline Compute pipeline to dispatch.
+ * @param bind     Bind group set at index 0.
+ * @param numWgs   Workgroup count along x (y = z = 1).
+ * @param reps     Number of timed submissions.
+ * @param passes   Compute passes recorded into every submission.
+ * @returns One wall-clock duration in milliseconds per timed submission.
+ */
+async function timeDispatch(
+  device: GPUDevice,
+  pipeline: GPUComputePipeline,
+  bind: GPUBindGroup,
+  numWgs: number,
+  reps: number,
+  passes: number,
+): Promise<number[]> {
+  // Records `passes` back-to-back compute passes into a fresh command buffer.
+  const encodeBatch = () => {
+    const enc = device.createCommandEncoder();
+    for (let pIdx = 0; pIdx < passes; pIdx++) {
+      const pass = enc.beginComputePass();
+      pass.setPipeline(pipeline);
+      pass.setBindGroup(0, bind);
+      pass.dispatchWorkgroups(numWgs, 1, 1);
+      pass.end();
+    }
+    return enc.finish();
+  };
+
+  // Warm-up: same workload, not measured.
+  device.queue.submit([encodeBatch()]);
+  await device.queue.onSubmittedWorkDone();
+
+  const samples: number[] = [];
+  for (let r = 0; r < reps; r++) {
+    const commands = encodeBatch();
+    const t0 = performance.now();
+    device.queue.submit([commands]);
+    await device.queue.onSubmittedWorkDone();
+    samples.push(performance.now() - t0);
+  }
+  return samples;
+}
+
+/**
+ * Benchmarks the disjoint pair-sum shader for one chunk size `s`:
+ * generates and compiles the shader, uploads freshly generated input,
+ * times `reps` submissions of DISP passes each, checks that the first
+ * 8 output words are not all zero, and returns the figures.
+ *
+ * Reads the module-level PAIRS, WGI and DISP. The generated shader
+ * source is also stored on window as `__shader_s<S>`.
+ *
+ * All four GPU buffers are destroyed in a `finally`, so a rejected
+ * timing run or readback no longer leaks them.
+ *
+ * @param device Device to run on.
+ * @param sm     Shader generator.
+ * @param s      Chunk size S; must divide PAIRS.
+ * @param reps   Timed submissions.
+ * @param R      Multiplier applied to every random coordinate (mod p).
+ * @param p      Field modulus.
+ * @param seed   Seed for the input generator.
+ * @returns The per-size result row.
+ * @throws Error when PAIRS is not a multiple of `s`, or when compilation fails.
+ */
+async function runOne(
+  device: GPUDevice,
+  sm: ShaderManager,
+  s: number,
+  reps: number,
+  R: bigint,
+  p: bigint,
+  seed: number,
+): Promise<PerSizeResult> {
+  if (PAIRS % s !== 0) throw new Error(`PAIRS=${PAIRS} must be a multiple of S=${s}`);
+  const T = PAIRS / s;
+  const numWgs = Math.ceil(T / WGI);
+  log('info', `=== S=${s}: PAIRS=${PAIRS} T=${T} WGI=${WGI} numWgs=${numWgs} DISP=${DISP}`);
+
+  const code = sm.gen_ba_pair_disjoint_bench_shader(WGI, s);
+  log('info', `compiling shader (${code.length} chars)`);
+  (window as unknown as Record<string, string>)[`__shader_s${s}`] = code;
+  const cacheKey = `bench-ba-pair-disjoint-W${WGI}-S${s}`;
+  const { pipeline, layout } = await compile(device, code, cacheKey);
+
+  const rng = makeRng(seed);
+  const inU32 = buildPackedPairsSoA(PAIRS, R, p, rng);
+
+  const inBuf = device.createBuffer({
+    size: inU32.byteLength,
+    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
+  });
+  // Placeholder for binding 1; it is never written by this bench.
+  const dummy = device.createBuffer({ size: 16, usage: GPUBufferUsage.STORAGE });
+  // Output: 2 planes x PG groups x PAIRS elements x 4 u32 x 4 bytes.
+  const outBytes = 2 * PG * PAIRS * 4 * 4;
+  const outBuf = device.createBuffer({
+    size: outBytes,
+    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
+  });
+  const paramsBuf = device.createBuffer({
+    size: 16,
+    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
+  });
+
+  try {
+    device.queue.writeBuffer(inBuf, 0, inU32);
+    // Uniform words: input element count per plane (2 * PAIRS), T, and two zero pads.
+    device.queue.writeBuffer(paramsBuf, 0, new Uint32Array([2 * PAIRS, T, 0, 0]));
+
+    const bind = device.createBindGroup({
+      layout,
+      entries: [
+        { binding: 0, resource: { buffer: inBuf } },
+        { binding: 1, resource: { buffer: dummy } },
+        { binding: 2, resource: { buffer: outBuf } },
+        { binding: 3, resource: { buffer: paramsBuf } },
+      ],
+    });
+
+    const samples = await timeDispatch(device, pipeline, bind, numWgs, reps, DISP);
+    const sanityOk = await readNonZero(device, outBuf, 8);
+    const med = median(samples);
+    const minMs = Math.min(...samples);
+    const maxMs = Math.max(...samples);
+    const totalOps = PAIRS * DISP;
+    // ms -> ns is * 1e6; divide by the pair sums produced per timed submission.
+    const nsPerOp = (med * 1e6) / totalOps;
+
+    log(
+      sanityOk ? 'ok' : 'err',
+      `S=${s}: median=${med.toFixed(3)}ms min=${minMs.toFixed(3)}ms max=${maxMs.toFixed(3)}ms ns/op=${nsPerOp.toFixed(2)} sanity=${sanityOk ? 'OK' : 'FAIL'}`,
+    );
+
+    return {
+      s,
+      wgi: WGI,
+      T,
+      num_wgs: numWgs,
+      pairs: PAIRS,
+      disp: DISP,
+      total_ops: totalOps,
+      median_ms: med,
+      min_ms: minMs,
+      max_ms: maxMs,
+      ns_per_op: nsPerOp,
+      samples_ms: samples,
+      sanity_ok: sanityOk,
+    };
+  } finally {
+    inBuf.destroy();
+    dummy.destroy();
+    outBuf.destroy();
+    paramsBuf.destroy();
+  }
+}
+
+// Reads the bench configuration from the page URL's query string
+// (?reps, ?pairs, ?wgi, ?disp, ?s) and returns it as one object.
+// Side effect: the module-level PAIRS, WGI, DISP and S_SWEEP are
+// overwritten one by one as each parameter validates, so a failure on a
+// later parameter leaves the earlier overrides in place.
+// Throws Error when a value is non-numeric or out of range, or when some
+// S does not divide PAIRS.
+// Note: parseInt accepts a numeric prefix, so "12abc" is read as 12.
+function parseParams() {
+ const qp = new URLSearchParams(window.location.search);
+ // reps falls back to 5 when the parameter is absent.
+ const reps = parseInt(qp.get('reps') ?? '5', 10);
+ if (!Number.isFinite(reps) || reps <= 0 || reps > 50) throw new Error(`?reps must be in (0, 50]`);
+ const pairsStr = qp.get('pairs');
+ if (pairsStr !== null) {
+ const v = parseInt(pairsStr, 10);
+ if (!Number.isFinite(v) || v <= 0 || v > (1 << 20)) throw new Error(`?pairs must be in (0, 2^20]`);
+ PAIRS = v;
+ }
+ const wgiStr = qp.get('wgi');
+ if (wgiStr !== null) {
+ const v = parseInt(wgiStr, 10);
+ if (!Number.isFinite(v) || v <= 0 || v > 1024) throw new Error(`?wgi must be in (0, 1024]`);
+ WGI = v;
+ }
+ const dispStr = qp.get('disp');
+ if (dispStr !== null) {
+ const v = parseInt(dispStr, 10);
+ if (!Number.isFinite(v) || v <= 0 || v > 64) throw new Error(`?disp must be in (0, 64]`);
+ DISP = v;
+ }
+ // ?s is a comma-separated list; every entry is validated before the sweep is replaced.
+ const sStr = qp.get('s');
+ if (sStr !== null) {
+ const list = sStr.split(',').map(v => parseInt(v, 10));
+ for (const v of list) {
+ if (!Number.isFinite(v) || v <= 0 || v > 256) throw new Error(`?s entries must be in (0, 256]`);
+ }
+ S_SWEEP = list;
+ }
+ // Checked last so it sees the final PAIRS and S_SWEEP, defaults included.
+ for (const v of S_SWEEP) {
+ if (PAIRS % v !== 0) throw new Error(`S=${v} does not divide PAIRS=${PAIRS}`);
+ }
+ return { reps, pairs: PAIRS, wgi: WGI, disp: DISP, s_sweep: S_SWEEP };
+}
+
+// Runs the whole bench: parse parameters, acquire the device, then run
+// runOne() for each S in S_SWEEP in order, stopping at the first failure.
+// It never rejects in normal operation: every failure is caught, logged
+// and written into benchState (state = 'error', error = message).
+async function main() {
+ try {
+ if (!('gpu' in navigator)) throw new Error('navigator.gpu missing — WebGPU not available');
+ const params = parseParams();
+ benchState.params = params;
+ log(
+ 'info',
+ `params: reps=${params.reps} pairs=${params.pairs} wgi=${params.wgi} disp=${params.disp} s=[${params.s_sweep.join(',')}]`,
+ );
+
+ benchState.state = 'running';
+ const device = await get_device();
+ log('info', 'WebGPU device acquired');
+
+ const p = BN254_BASE_FIELD;
+ // NOTE(review): the meaning of the literal 13 is not visible here; confirm against compute_misc_params.
+ const miscParams = compute_misc_params(p, 13);
+ const R = miscParams.r;
+
+ const sm = new ShaderManager(4, PAIRS, BN254_CURVE_CONFIG, false);
+
+ // Each S gets its own seed; it advances only after a successful run.
+ let seed = 0xd1d1;
+ for (const s of S_SWEEP) {
+ try {
+ const r = await runOne(device, sm, s, params.reps, R, p, seed);
+ benchState.results.push(r);
+ resultsClient.postProgress({
+ kind: 'batch_done',
+ s,
+ median_ms: r.median_ms,
+ ns_per_op: r.ns_per_op,
+ sanity_ok: r.sanity_ok,
+ });
+ seed += 0x10;
+ } catch (e) {
+ // A failure for one S aborts the remaining sweep; earlier results are kept.
+ const msg = e instanceof Error ? e.message : String(e);
+ log('err', `S=${s} failed: ${msg} — STOPPING`);
+ benchState.state = 'error';
+ benchState.error = msg;
+ return;
+ }
+ }
+
+ benchState.state = 'done';
+ log('ok', 'all sizes done');
+ } catch (e) {
+ // Setup failures (no WebGPU, bad params, no device) land here with the stack attached.
+ const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e);
+ log('err', `FATAL: ${msg}`);
+ benchState.state = 'error';
+ benchState.error = msg;
+ }
+}
+
+// Entry point. The .catch is a backstop for anything that escapes main()'s
+// own handlers; the .finally posts the final state on both success and
+// failure, and errors from the post itself are deliberately ignored.
+main()
+ .catch(e => {
+ const msg = e instanceof Error ? e.message : String(e);
+ log('err', `unhandled: ${msg}`);
+ benchState.state = 'error';
+ benchState.error = msg;
+ })
+ .finally(() => {
+ postFinal().catch(() => {});
+ });
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.html b/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.html
new file mode 100644
index 000000000000..6676d4772bd2
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.html
@@ -0,0 +1,37 @@
+
+
+
+
+ ba_rev_packed_carry standalone batch-affine bench (WebGPU)
+
+
+
+ ba_rev_packed_carry standalone batch-affine bench (WebGPU)
+ Query params: ?reps=R&pairs=N&wgi=W&disp=D&s=A,B,C
+
+
+
+
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.ts b/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.ts
new file mode 100644
index 000000000000..2e55c2c9b7e4
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.ts
@@ -0,0 +1,438 @@
+///
+// Standalone WebGPU bench for the canonical ba_rev_packed_carry kernel
+// (recovered from commit eab3a3e). SoA-packed 8x u32 storage across 4
+// input planes (A.x, A.y, P.x, P.y), strided per-thread access
+// e = t + i*T, single fr_inv_by_a per S-chunk, lean affine apply with
+// resident-accumulator load-carry. Each timed sample submits DISP back-
+// to-back dispatches in one command encoder to amortise submit+drain.
+//
+// Math: bucket-accumulate streaming chain (R_i = A_i + P_i where A_0 is
+// the seed, A_{i+1} := P_i). With random P.x != A.x inputs, every dx is
+// nonzero so the batched inverse is well-defined. Sanity check reads
+// the first packed elem of R.x and asserts non-zero.
+
+import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
+import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
+import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
+import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js';
+import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js';
+import { makeResultsClient } from './results_post.js';
+
+// Defaults used when the matching query parameter is absent.
+const DEFAULT_PAIRS = 1 << 17; // 131072
+const DEFAULT_WGI = 64;
+const DEFAULT_DISP = 8;
+const DEFAULT_S_SWEEP: readonly number[] = [16, 32, 64];
+const PG = 2; // 8 packed u32 / 4 = 2 vec4 groups per element
+
+// Live configuration. parseParams() overwrites these from the URL query
+// string; runOne() reads them directly rather than taking them as arguments.
+// PAIRS   - elements per plane, and the op count per dispatch used for ns/op
+// WGI     - first argument to the shader generator and divisor for the workgroup count
+// DISP    - compute passes recorded into each timed submission
+// S_SWEEP - chunk sizes S to benchmark, in order
+let PAIRS = DEFAULT_PAIRS;
+let WGI = DEFAULT_WGI;
+let DISP = DEFAULT_DISP;
+let S_SWEEP: readonly number[] = DEFAULT_S_SWEEP;
+
+/**
+ * Creates a deterministic 32-bit linear congruential generator
+ * (multiplier 1664525, increment 1013904223, modulus 2^32).
+ *
+ * @param seed Starting state; coerced to uint32, and a zero result is replaced by 1.
+ * @returns A function that advances the generator and returns the new uint32 state.
+ */
+function makeRng(seed: number): () => number {
+  const LCG_MUL = 1664525;
+  const LCG_INC = 1013904223;
+  const seed32 = seed >>> 0;
+  let current = seed32 === 0 ? 1 : seed32;
+  return function next(): number {
+    // Math.imul keeps the product in 32-bit integer range; `>>> 0` makes it unsigned.
+    current = (Math.imul(current, LCG_MUL) + LCG_INC) >>> 0;
+    return current;
+  };
+}
+
+/**
+ * Draws a bigint in the open interval (0, p) by rejection sampling:
+ * assemble ceil(bitlen(p) / 8) bytes from `rng`, mask down to bitlen(p)
+ * bits, and retry until the candidate is non-zero and below p.
+ *
+ * Each byte comes from the TOP 8 bits of the 32-bit rng output. With a
+ * power-of-two-modulus LCG (as produced by makeRng) the low byte cycles
+ * with period 256, so sampling `& 0xff` made 32-byte candidates repeat
+ * after only 8 draws. The high byte has the generator's full period.
+ *
+ * @param p   Exclusive upper bound; must be greater than 1.
+ * @param rng Source of uint32 values.
+ * @returns A value v with 0 < v < p.
+ * @throws RangeError when p <= 1 (no value exists in (0, p), so the loop could never end).
+ */
+function randomBelow(p: bigint, rng: () => number): bigint {
+  if (p <= 1n) throw new RangeError(`randomBelow: p must be > 1, got ${p}`);
+  const bitlen = p.toString(2).length;
+  const byteLen = Math.ceil(bitlen / 8);
+  // Loop-invariant: trims the assembled bytes to exactly bitlen bits.
+  const mask = (1n << BigInt(bitlen)) - 1n;
+  for (;;) {
+    let v = 0n;
+    for (let i = 0; i < byteLen; i++) v = (v << 8n) | BigInt((rng() >>> 24) & 0xff);
+    v &= mask;
+    if (v > 0n && v < p) return v;
+  }
+}
+
+/**
+ * Splits the low 256 bits of `v` into eight 32-bit limbs, least
+ * significant limb first.
+ *
+ * @param v Value to pack; bits at position 256 and above are dropped.
+ * @returns A new Uint32Array of length 8.
+ */
+function bigintToPackedU32x8(v: bigint): Uint32Array {
+  const limbs = new Uint32Array(8);
+  for (let limb = 0, shift = 0n; limb < 8; limb++, shift += 32n) {
+    limbs[limb] = Number(BigInt.asUintN(32, v >> shift));
+  }
+  return limbs;
+}
+
+// Builds the SoA-packed input: 4 planes (A.x, A.y, P.x, P.y), each holding
+// N = pairs elements split into PG groups of 4 u32. Group v of plane c for
+// pair e starts at u32 index ((c * PG + v) * N + e) * 4.
+// A.x and P.x of the same pair are always different. Every coordinate is
+// a randomBelow(p) draw multiplied by R mod p.
+function packAffineSoAPacked(pairs: number, R: bigint, p: bigint, rng: () => number): Uint32Array {
+  const elemsPerPlane = pairs;
+  const packed = new Uint32Array(4 * PG * elemsPerPlane * 4);
+  // One fresh coordinate: random draw scaled by R, reduced mod p.
+  const nextCoord = (): bigint => (randomBelow(p, rng) * R) % p;
+
+  for (let e = 0; e < elemsPerPlane; e++) {
+    // Redraw both x coordinates together until they differ.
+    let accX = nextCoord();
+    let ptX = nextCoord();
+    while (accX === ptX) {
+      accX = nextCoord();
+      ptX = nextCoord();
+    }
+    const accY = nextCoord();
+    const ptY = nextCoord();
+
+    // Plane order is A.x, A.y, P.x, P.y.
+    [accX, accY, ptX, ptY].forEach((coord, plane) => {
+      const limbs = bigintToPackedU32x8(coord);
+      for (let g = 0; g < PG; g++) {
+        const dst = ((plane * PG + g) * elemsPerPlane + e) * 4;
+        packed.set(limbs.subarray(4 * g, 4 * g + 4), dst);
+      }
+    });
+  }
+  return packed;
+}
+
+/**
+ * Returns the middle element of `xs` in ascending order. For an
+ * even-length input this is the upper of the two middle elements (no
+ * averaging). The input array is not modified.
+ *
+ * @returns NaN when `xs` is empty.
+ */
+function median(xs: number[]): number {
+  const count = xs.length;
+  if (count === 0) return NaN;
+  const ascending = [...xs].sort((lhs, rhs) => lhs - rhs);
+  return ascending[count >> 1];
+}
+
+// One row of the sweep: configuration and timing for a single chunk size S,
+// as assembled at the end of runOne().
+interface PerSizeResult {
+ // Chunk size S this row was measured with.
+ s: number;
+ // WGI value in effect for the run.
+ wgi: number;
+ // PAIRS / s.
+ T: number;
+ // Workgroups per dispatch: ceil(T / WGI).
+ num_wgs: number;
+ // PAIRS value in effect for the run.
+ pairs: number;
+ // Compute passes per timed submission (DISP).
+ disp: number;
+ // pairs * disp; the divisor used for ns_per_op.
+ total_ops: number;
+ // median() of samples_ms.
+ median_ms: number;
+ min_ms: number;
+ max_ms: number;
+ // median_ms * 1e6 / total_ops.
+ ns_per_op: number;
+ // Wall-clock milliseconds of each timed submission, in run order.
+ samples_ms: number[];
+ // True when the first 8 u32 of the output buffer were not all zero.
+ sanity_ok: boolean;
+}
+
+// Whole-run status object; it is published on window.__bench and is the
+// source of the payload sent by postFinal().
+interface BenchState {
+ // Lifecycle: 'boot' until params parse, 'running' during the sweep, then 'done' or 'error'.
+ state: 'boot' | 'running' | 'done' | 'error';
+ // Parsed query parameters; null until parseParams() has succeeded.
+ params: { reps: number; pairs: number; wgi: number; disp: number; s_sweep: readonly number[] } | null;
+ // One entry per completed S value, in sweep order.
+ results: PerSizeResult[];
+ // Message of the failure that ended the run, or null.
+ error: string | null;
+ // Every log() message, formatted as "[level] msg".
+ log: string[];
+}
+
+// Single mutable run-state instance, updated by log(), main() and the entry chain.
+const benchState: BenchState = {
+ state: 'boot',
+ params: null,
+ results: [],
+ error: null,
+ log: [],
+};
+// Publish the state on window so it can be inspected from outside the module.
+(window as unknown as { __bench: BenchState }).__bench = benchState;
+
+// Results reporter for this page; its run id is also published on window.
+const resultsClient = makeResultsClient({ page: 'bench-ba-rev-packed-carry' });
+(window as unknown as { __runId: string }).__runId = resultsClient.runId;
+
+/**
+ * Sends the final snapshot of benchState (state, params, results, error,
+ * log) plus the browser's user agent and hardwareConcurrency through the
+ * results client.
+ *
+ * The return type is spelled `Promise<void>`: a bare `Promise` is
+ * rejected by the compiler because the generic needs its type argument.
+ *
+ * @returns A promise that settles when resultsClient.postResults settles.
+ */
+async function postFinal(): Promise<void> {
+  await resultsClient.postResults({
+    state: benchState.state,
+    params: benchState.params,
+    results: benchState.results,
+    error: benchState.error,
+    log: benchState.log,
+    userAgent: navigator.userAgent,
+    hardwareConcurrency: navigator.hardwareConcurrency,
+  });
+}
+
+// On-page log container; null when the host page has no element with id "log".
+const $log = document.getElementById('log');
+
+/**
+ * Records one message in benchState.log and the console, then appends
+ * it to the on-page log when that element exists.
+ *
+ * The DOM write is last and null-guarded. log() is called from the
+ * error handlers, so it must not throw before the message is recorded;
+ * an unguarded append on a missing element would do exactly that.
+ *
+ * @param level Severity; also used as the row's CSS class ('info' maps to no class).
+ * @param msg   Text to record.
+ */
+function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) {
+  benchState.log.push(`[${level}] ${msg}`);
+  console.log(`[bench-ba-rev-packed-carry] ${msg}`);
+  if ($log === null) return;
+  const row = document.createElement('div');
+  row.className = level === 'info' ? '' : level;
+  row.textContent = msg;
+  $log.appendChild(row);
+}
+
+// Compiles `code` into a compute pipeline bound to this bench's fixed
+// four-entry bind group layout.
+//   device   - WebGPU device used for all object creation
+//   code     - shader source text
+//   cacheKey - label used only to tag diagnostics and the thrown error
+// Returns the pipeline together with the explicit layout so the caller
+// can build a matching bind group.
+// Throws Error if the compilation info contains any message of type 'error'.
+async function createPipeline(
+ device: GPUDevice,
+ code: string,
+ cacheKey: string,
+): Promise<{ pipeline: GPUComputePipeline; layout: GPUBindGroupLayout }> {
+ const module = device.createShaderModule({ code });
+ const info = await module.getCompilationInfo();
+ let hasError = false;
+ const errLines: string[] = [];
+ // Surface every diagnostic; only 'error' entries abort the compile.
+ for (const msg of info.messages) {
+ const line = `[shader ${cacheKey}] ${msg.type}: ${msg.message} (line ${msg.lineNum}, col ${msg.linePos})`;
+ if (msg.type === 'error') {
+ console.error(line);
+ log('err', line);
+ errLines.push(line);
+ hasError = true;
+ } else {
+ console.warn(line);
+ log('warn', line);
+ }
+ }
+ if (hasError) {
+ throw new Error(`WGSL compile failed for ${cacheKey}: ${errLines.join(' | ')}`);
+ }
+ // Bindings 0 and 1: read-only storage; 2: writable storage; 3: uniform.
+ const layout = device.createBindGroupLayout({
+ entries: [
+ { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+ { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+ { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+ { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } },
+ ],
+ });
+ const pipeline = await device.createComputePipelineAsync({
+ layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }),
+ compute: { module, entryPoint: 'main' },
+ });
+ return { pipeline, layout };
+}
+
+/**
+ * Copies the first `u32Count` 32-bit words of `out` into a temporary
+ * mappable buffer and reports whether any of them is non-zero.
+ *
+ * The staging buffer is destroyed in a `finally`, so it is released
+ * even when the copy or the map rejects. The mapped range is scanned
+ * in place before unmapping, so no intermediate copy is made.
+ *
+ * @param device   Device that owns `out`.
+ * @param out      Source buffer; must have been created with COPY_SRC usage.
+ * @param u32Count Number of leading u32 words to inspect.
+ * @returns true if at least one inspected word is non-zero.
+ */
+async function readNonZero(device: GPUDevice, out: GPUBuffer, u32Count: number): Promise<boolean> {
+  const bytes = u32Count * 4;
+  const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST });
+  try {
+    const enc = device.createCommandEncoder();
+    enc.copyBufferToBuffer(out, 0, staging, 0, bytes);
+    device.queue.submit([enc.finish()]);
+    await device.queue.onSubmittedWorkDone();
+    await staging.mapAsync(GPUMapMode.READ);
+    // The view is only valid while mapped, so finish the scan before unmap().
+    const found = new Uint32Array(staging.getMappedRange()).some(word => word !== 0);
+    staging.unmap();
+    return found;
+  } finally {
+    staging.destroy();
+  }
+}
+
+/**
+ * Times `reps` submissions, each containing `passes` identical compute
+ * passes, after one untimed warm-up submission of the same shape.
+ *
+ * Command encoding happens before the clock starts; each sample covers
+ * queue.submit() through onSubmittedWorkDone() only.
+ *
+ * @param device   Device whose queue runs the work.
+ * @param pipeline Compute pipeline to dispatch.
+ * @param bind     Bind group set at index 0.
+ * @param numWgs   Workgroup count along x (y = z = 1).
+ * @param reps     Number of timed submissions.
+ * @param passes   Compute passes recorded into every submission.
+ * @returns One wall-clock duration in milliseconds per timed submission.
+ */
+async function timeDispatch(
+  device: GPUDevice,
+  pipeline: GPUComputePipeline,
+  bind: GPUBindGroup,
+  numWgs: number,
+  reps: number,
+  passes: number,
+): Promise<number[]> {
+  // Records `passes` back-to-back compute passes into a fresh command buffer.
+  const encodeBatch = () => {
+    const enc = device.createCommandEncoder();
+    for (let pIdx = 0; pIdx < passes; pIdx++) {
+      const pass = enc.beginComputePass();
+      pass.setPipeline(pipeline);
+      pass.setBindGroup(0, bind);
+      pass.dispatchWorkgroups(numWgs, 1, 1);
+      pass.end();
+    }
+    return enc.finish();
+  };
+
+  // warmup: same workload, not measured
+  device.queue.submit([encodeBatch()]);
+  await device.queue.onSubmittedWorkDone();
+
+  const samples: number[] = [];
+  for (let r = 0; r < reps; r++) {
+    const commands = encodeBatch();
+    const t0 = performance.now();
+    device.queue.submit([commands]);
+    await device.queue.onSubmittedWorkDone();
+    samples.push(performance.now() - t0);
+  }
+  return samples;
+}
+
+/**
+ * Benchmarks the ba_rev_packed_carry shader for one chunk size `s`:
+ * generates and compiles the shader, uploads freshly generated input,
+ * times `reps` submissions of DISP passes each, checks that the first
+ * 8 output words are not all zero, and returns the figures.
+ *
+ * Reads the module-level PAIRS, WGI and DISP. The generated shader
+ * source is also stored on window as `__shader_s<S>`.
+ *
+ * All four GPU buffers are destroyed in a `finally`, so a rejected
+ * timing run or readback no longer leaks them.
+ *
+ * @param device Device to run on.
+ * @param sm     Shader generator.
+ * @param s      Chunk size S; must divide PAIRS.
+ * @param reps   Timed submissions.
+ * @param R      Multiplier applied to every random coordinate (mod p).
+ * @param p      Field modulus.
+ * @param seed   Seed for the input generator.
+ * @returns The per-size result row.
+ * @throws Error when PAIRS is not a multiple of `s`, or when compilation fails.
+ */
+async function runOne(
+  device: GPUDevice,
+  sm: ShaderManager,
+  s: number,
+  reps: number,
+  R: bigint,
+  p: bigint,
+  seed: number,
+): Promise<PerSizeResult> {
+  if (PAIRS % s !== 0) {
+    throw new Error(`PAIRS=${PAIRS} must be a multiple of S=${s}`);
+  }
+  const T = PAIRS / s;
+  const numWgs = Math.ceil(T / WGI);
+  log('info', `=== S=${s}: PAIRS=${PAIRS} T=${T} WGI=${WGI} numWgs=${numWgs} DISP=${DISP}`);
+
+  const code = sm.gen_ba_rev_packed_carry_bench_shader(WGI, s);
+  const cacheKey = `bench-ba-rev-packed-carry-W${WGI}-S${s}`;
+  log('info', `compiling shader (${code.length} chars)`);
+  (window as unknown as Record<string, string>)[`__shader_s${s}`] = code;
+  const { pipeline, layout } = await createPipeline(device, code, cacheKey);
+
+  const rng = makeRng(seed);
+  const inU32 = packAffineSoAPacked(PAIRS, R, p, rng);
+
+  const inBuf = device.createBuffer({
+    size: inU32.byteLength,
+    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
+  });
+  // Placeholder for binding 1; it is never written by this bench.
+  const dummy = device.createBuffer({ size: 16, usage: GPUBufferUsage.STORAGE });
+  // Output: 2 planes x PG groups x PAIRS elements x 4 u32 x 4 bytes.
+  const outBytes = 2 * PG * PAIRS * 4 * 4;
+  const outBuf = device.createBuffer({
+    size: outBytes,
+    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
+  });
+  const paramsBuf = device.createBuffer({
+    size: 16,
+    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
+  });
+
+  try {
+    device.queue.writeBuffer(inBuf, 0, inU32);
+    // Uniform words: element count per plane (PAIRS), T, and two zero pads.
+    device.queue.writeBuffer(paramsBuf, 0, new Uint32Array([PAIRS, T, 0, 0]));
+
+    const bind = device.createBindGroup({
+      layout,
+      entries: [
+        { binding: 0, resource: { buffer: inBuf } },
+        { binding: 1, resource: { buffer: dummy } },
+        { binding: 2, resource: { buffer: outBuf } },
+        { binding: 3, resource: { buffer: paramsBuf } },
+      ],
+    });
+
+    const samples = await timeDispatch(device, pipeline, bind, numWgs, reps, DISP);
+    const sanityOk = await readNonZero(device, outBuf, 8);
+    const med = median(samples);
+    const minMs = Math.min(...samples);
+    const maxMs = Math.max(...samples);
+    const totalOps = PAIRS * DISP;
+    // ms -> ns is * 1e6; divide by the ops performed per timed submission.
+    const nsPerOp = (med * 1e6) / totalOps;
+
+    log(
+      sanityOk ? 'ok' : 'err',
+      `S=${s}: median=${med.toFixed(3)}ms min=${minMs.toFixed(3)}ms max=${maxMs.toFixed(3)}ms ns/op=${nsPerOp.toFixed(2)} sanity=${sanityOk ? 'OK' : 'FAIL'}`,
+    );
+
+    return {
+      s,
+      wgi: WGI,
+      T,
+      num_wgs: numWgs,
+      pairs: PAIRS,
+      disp: DISP,
+      total_ops: totalOps,
+      median_ms: med,
+      min_ms: minMs,
+      max_ms: maxMs,
+      ns_per_op: nsPerOp,
+      samples_ms: samples,
+      sanity_ok: sanityOk,
+    };
+  } finally {
+    inBuf.destroy();
+    dummy.destroy();
+    outBuf.destroy();
+    paramsBuf.destroy();
+  }
+}
+
+// Reads the bench configuration from the page URL's query string
+// (?reps, ?pairs, ?wgi, ?disp, ?s) and returns it as one object.
+// Side effect: the module-level PAIRS, WGI, DISP and S_SWEEP are
+// overwritten one by one as each parameter validates, so a failure on a
+// later parameter leaves the earlier overrides in place.
+// Throws Error (quoting the offending value) when a value is non-numeric
+// or out of range, or when some S does not divide PAIRS.
+// Note: parseInt accepts a numeric prefix, so "12abc" is read as 12.
+function parseParams() {
+ const qp = new URLSearchParams(window.location.search);
+ // reps falls back to 5 when the parameter is absent.
+ const reps = parseInt(qp.get('reps') ?? '5', 10);
+ if (!Number.isFinite(reps) || reps <= 0 || reps > 50) {
+ throw new Error(`?reps must be in (0, 50], got ${qp.get('reps')}`);
+ }
+ const pairsStr = qp.get('pairs');
+ if (pairsStr !== null) {
+ const v = parseInt(pairsStr, 10);
+ if (!Number.isFinite(v) || v <= 0 || v > (1 << 20)) {
+ throw new Error(`?pairs must be in (0, 2^20], got ${pairsStr}`);
+ }
+ PAIRS = v;
+ }
+ const wgiStr = qp.get('wgi');
+ if (wgiStr !== null) {
+ const v = parseInt(wgiStr, 10);
+ if (!Number.isFinite(v) || v <= 0 || v > 1024) {
+ throw new Error(`?wgi must be in (0, 1024], got ${wgiStr}`);
+ }
+ WGI = v;
+ }
+ const dispStr = qp.get('disp');
+ if (dispStr !== null) {
+ const v = parseInt(dispStr, 10);
+ if (!Number.isFinite(v) || v <= 0 || v > 64) {
+ throw new Error(`?disp must be in (0, 64], got ${dispStr}`);
+ }
+ DISP = v;
+ }
+ // ?s is a comma-separated list; every entry is validated before the sweep is replaced.
+ const sStr = qp.get('s');
+ if (sStr !== null) {
+ const list = sStr.split(',').map(v => parseInt(v, 10));
+ for (const v of list) {
+ if (!Number.isFinite(v) || v <= 0 || v > 256) {
+ throw new Error(`?s entries must be in (0, 256], got ${v}`);
+ }
+ }
+ S_SWEEP = list;
+ }
+ // Checked last so it sees the final PAIRS and S_SWEEP, defaults included.
+ for (const v of S_SWEEP) {
+ if (PAIRS % v !== 0) {
+ throw new Error(`S=${v} does not divide PAIRS=${PAIRS}`);
+ }
+ }
+ return { reps, pairs: PAIRS, wgi: WGI, disp: DISP, s_sweep: S_SWEEP };
+}
+
+// Runs the whole bench: parse parameters, acquire the device, then run
+// runOne() for each S in S_SWEEP in order, stopping at the first failure.
+// It never rejects in normal operation: every failure is caught, logged
+// and written into benchState (state = 'error', error = message).
+async function main() {
+ try {
+ if (!('gpu' in navigator)) {
+ throw new Error('navigator.gpu missing — WebGPU not available');
+ }
+ const params = parseParams();
+ benchState.params = params;
+ log(
+ 'info',
+ `params: reps=${params.reps} pairs=${params.pairs} wgi=${params.wgi} disp=${params.disp} s=[${params.s_sweep.join(',')}]`,
+ );
+
+ benchState.state = 'running';
+ const device = await get_device();
+ log('info', 'WebGPU device acquired');
+
+ const p = BN254_BASE_FIELD;
+ // NOTE(review): the meaning of the literal 13 is not visible here; confirm against compute_misc_params.
+ const miscParams = compute_misc_params(p, 13);
+ const R = miscParams.r;
+
+ const sm = new ShaderManager(4, PAIRS, BN254_CURVE_CONFIG, false);
+
+ // Each S gets its own seed; it advances only after a successful run.
+ let seed = 0x7b10;
+ for (const s of S_SWEEP) {
+ try {
+ const r = await runOne(device, sm, s, params.reps, R, p, seed);
+ benchState.results.push(r);
+ resultsClient.postProgress({ kind: 'batch_done', s, median_ms: r.median_ms, ns_per_op: r.ns_per_op, sanity_ok: r.sanity_ok });
+ seed += 0x10;
+ } catch (e) {
+ // A failure for one S aborts the remaining sweep; earlier results are kept.
+ const msg = e instanceof Error ? e.message : String(e);
+ log('err', `S=${s} failed: ${msg} — STOPPING sweep at first failure`);
+ benchState.state = 'error';
+ benchState.error = msg;
+ return;
+ }
+ }
+
+ benchState.state = 'done';
+ log('ok', 'all sizes done');
+ } catch (e) {
+ // Setup failures (no WebGPU, bad params, no device) land here with the stack attached.
+ const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e);
+ log('err', `FATAL: ${msg}`);
+ benchState.state = 'error';
+ benchState.error = msg;
+ }
+}
+
+// Entry point. The .catch is a backstop for anything that escapes main()'s
+// own handlers; the .finally posts the final state on both success and
+// failure, and errors from the post itself are deliberately ignored.
+main()
+ .catch(e => {
+ const msg = e instanceof Error ? e.message : String(e);
+ log('err', `unhandled: ${msg}`);
+ benchState.state = 'error';
+ benchState.error = msg;
+ })
+ .finally(() => {
+ postFinal().catch(() => {});
+ });
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.html b/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.html
new file mode 100644
index 000000000000..31b300972d08
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.html
@@ -0,0 +1,22 @@
+
+
+
+
+ cuZK CSR -> v2 active_sums layout converter bench (WebGPU)
+
+
+
+ cuZK CSR -> v2 active_sums layout converter
+ Query params: ?subtasks=T&columns=C&input=N&wg=W&disp=D&reps=R&validate=1
+
+
+
+
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.ts b/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.ts
new file mode 100644
index 000000000000..44a0458b96d1
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.ts
@@ -0,0 +1,517 @@
+///
+// Standalone microbench + validator for the cuZK CSR -> v2 active_sums
+// layout converter. The converter consumes the cuZK transpose output
+// (val_idx + row_ptr) and the packed 8xu32 cached bases, and emits the
+// bucket-major active_sums + per-bucket counts/offsets that the v2
+// pair-tree expects.
+//
+// Inputs are synthetic — random bucket assignments per (subtask, scalar)
+// — so the harness validates the converter in isolation without needing
+// the full cuZK convert + decompose + transpose pipeline.
+//
+// Validation: byte-equivalent compare of active_sums_x/y, active_counts,
+// and active_offsets against a host-side reference. The reference is
+// trivial: active_sums[k] = new_point[val_idx[k]], counts[b] =
+// row_ptr[b+1]-row_ptr[b], offsets[b] = row_ptr[b].
+//
+// Timing methodology mirrors bench-planner: warmup, then DISP back-to-
+// back dispatches in one encoder, divided by DISP for per-call time.
+
+import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
+import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
+import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
+import { makeResultsClient } from './results_post.js';
+
+let NUM_SUBTASKS = 16; // bench defaults; parseParams() may override each from the query string — this one via ?subtasks=
+let NUM_COLUMNS = 4096; // ?columns= — buckets per subtask (rowPtr holds NUM_COLUMNS + 1 entries per subtask)
+let INPUT_SIZE = 1 << 14; // ?input= — valIdx entries per subtask, and number of synthetic bases
+let WG = 64; // ?wg= — passed to the shader generators and used to size the dispatch grids
+let DISP = 32; // ?disp= — dispatches encoded per timed submit; sample times are divided by this
+let REPS = 5; // ?reps= — timed samples collected per measurement
+let VALIDATE = false; // ?validate=1 — compare GPU output against the host reference after warmup
+let SEED = 0xc5af; // ?seed= — seeds the synthetic CSR, and (XOR 0xfeed) the synthetic bases
+
+/**
+ * Builds a deterministic 32-bit linear congruential generator.
+ *
+ * @param seed Coerced to an unsigned 32-bit value; a zero result is replaced by 1.
+ * @returns A function that advances the state and returns it as a uint32.
+ */
+function makeRng(seed: number): () => number {
+  const MULTIPLIER = 1664525;
+  const INCREMENT = 1013904223;
+  const normalized = seed >>> 0;
+  let current = normalized === 0 ? 1 : normalized;
+  return function next(): number {
+    current = (Math.imul(current, MULTIPLIER) + INCREMENT) >>> 0;
+    return current;
+  };
+}
+
+/**
+ * Generates a synthetic CSR structure: for every subtask, each of the
+ * `inputSize` entries is assigned a pseudo-random bucket in [0, numColumns),
+ * and the entries are then grouped by bucket.
+ *
+ * @param numSubtasks Number of independent subtasks.
+ * @param numColumns  Buckets per subtask.
+ * @param inputSize   Entries per subtask.
+ * @param seed        Seed for the deterministic generator.
+ * @returns `rowPtr` (numColumns + 1 prefix-sum entries per subtask) and
+ *          `valIdx` (inputSize entry indices per subtask, ordered by bucket).
+ */
+function buildSyntheticCsr(
+  numSubtasks: number,
+  numColumns: number,
+  inputSize: number,
+  seed: number,
+): { rowPtr: Uint32Array; valIdx: Uint32Array } {
+  const nextU32 = makeRng(seed);
+  // Only the upper 16 bits of each draw are used.
+  const draw16 = (): number => (nextU32() >>> 16) & 0xffff;
+  const rowStride = numColumns + 1;
+  const rowPtr = new Uint32Array(numSubtasks * rowStride);
+  const valIdx = new Uint32Array(numSubtasks * inputSize);
+
+  for (let subtask = 0; subtask < numSubtasks; subtask++) {
+    const rowBase = subtask * rowStride;
+    const valBase = subtask * inputSize;
+
+    // Pass 1: pick a bucket per entry and build the histogram.
+    const assigned = new Uint32Array(inputSize);
+    const histogram = new Uint32Array(numColumns);
+    for (let entry = 0; entry < inputSize; entry++) {
+      const upper = draw16();
+      const lower = draw16();
+      const bucket = (upper * 0x10000 + lower) % numColumns;
+      assigned[entry] = bucket;
+      histogram[bucket]++;
+    }
+
+    // Pass 2: exclusive prefix sum of the histogram gives the row pointers.
+    rowPtr[rowBase] = 0;
+    for (let bucket = 0; bucket < numColumns; bucket++) {
+      rowPtr[rowBase + bucket + 1] = rowPtr[rowBase + bucket] + histogram[bucket];
+    }
+
+    // Pass 3: scatter entries into their bucket's range, in entry order.
+    const filled = new Uint32Array(numColumns);
+    for (let entry = 0; entry < inputSize; entry++) {
+      const bucket = assigned[entry];
+      const dest = rowPtr[rowBase + bucket] + filled[bucket];
+      valIdx[valBase + dest] = entry;
+      filled[bucket]++;
+    }
+  }
+  return { rowPtr, valIdx };
+}
+
+/**
+ * Returns the middle element of the sorted samples (the upper-middle one for
+ * even lengths), or NaN for an empty list. The input is not modified.
+ */
+function median(xs: number[]): number {
+  const count = xs.length;
+  if (count === 0) {
+    return NaN;
+  }
+  const ascending = [...xs].sort((lhs, rhs) => lhs - rhs);
+  return ascending[Math.floor(count / 2)];
+}
+
+interface BenchResult { // one runOne() measurement: the configuration used plus the timing statistics
+ num_subtasks: number; // NUM_SUBTASKS at run time
+ num_columns: number; // NUM_COLUMNS at run time
+ input_size: number; // INPUT_SIZE at run time
+ total_slots: number; // NUM_SUBTASKS * INPUT_SIZE
+ total_buckets: number; // NUM_SUBTASKS * NUM_COLUMNS
+ wg: number; // WG at run time
+ disp: number; // DISP at run time
+ reps: number; // REPS at run time
+ active_sums_us: { min: number; median: number; max: number }; // active_sums pipeline, microseconds per dispatch (sample ms / DISP * 1000)
+ meta_us: { min: number; median: number; max: number }; // meta pipeline, same units
+ combined_us: { min: number; median: number; max: number }; // active_sums + meta encoded back to back, same units
+ ns_per_slot_combined: number; // combined median (us) * 1000 / total_slots
+ validated: boolean; // true only when validation ran and found no mismatch
+}
+
+/**
+ * Mutable run status for this page. It is exposed on `window.__bench` and its
+ * fields are sent by postFinal().
+ */
+interface BenchState {
+  /** Lifecycle stage of the run. */
+  state: 'boot' | 'running' | 'done' | 'error';
+  /** Effective parameters as returned by parseParams(); null until parsed. */
+  params: Record<string, unknown> | null;
+  /** One entry per completed runOne() call. */
+  results: BenchResult[];
+  /** Message of the failure that ended the run, if any. */
+  error: string | null;
+  /** Every line passed to log(), prefixed with its level. */
+  log: string[];
+}
+
+const benchState: BenchState = { state: 'boot', params: null, results: [], error: null, log: [] }; // single mutable status record for this page
+(window as unknown as { __bench: BenchState }).__bench = benchState; // expose the live state on window.__bench
+const resultsClient = makeResultsClient({ page: 'bench-csr-to-v2' }); // client used for postProgress() / postResults()
+(window as unknown as { __runId: string }).__runId = resultsClient.runId; // expose the client's run id on window.__runId
+
+/**
+ * Sends the final snapshot of benchState, plus the browser's user agent and
+ * hardwareConcurrency, through the results client.
+ *
+ * @returns A promise that settles when postResults() settles.
+ */
+async function postFinal(): Promise<void> {
+  await resultsClient.postResults({
+    state: benchState.state,
+    params: benchState.params,
+    results: benchState.results,
+    error: benchState.error,
+    log: benchState.log,
+    userAgent: navigator.userAgent,
+    hardwareConcurrency: navigator.hardwareConcurrency,
+  });
+}
+
+const $log = document.getElementById('log') as HTMLDivElement;
+/**
+ * Appends a line to the on-page log element, records it in benchState.log
+ * with its level prefix, and mirrors it to the console.
+ *
+ * @param level Severity; 'ok', 'err' and 'warn' become the row's CSS class, 'info' gets none.
+ * @param msg   Text to display.
+ */
+function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) {
+  let cssClass: string;
+  switch (level) {
+    case 'ok':
+    case 'err':
+    case 'warn':
+      cssClass = level;
+      break;
+    default:
+      cssClass = '';
+  }
+  const row = document.createElement('div');
+  row.className = cssClass;
+  row.textContent = msg;
+  $log.appendChild(row);
+  benchState.log.push(`[${level}] ${msg}`);
+  console.log(`[bench-csr-to-v2] ${msg}`);
+}
+
+/**
+ * Compiles a WGSL module, surfaces its compilation messages, and builds a
+ * compute pipeline (entry point `main`) using the given bind group layout.
+ *
+ * @param device GPU device used to create the module and pipeline.
+ * @param code   WGSL source.
+ * @param key    Label used in diagnostics.
+ * @param layout The single bind group layout of the pipeline.
+ * @returns The compiled compute pipeline.
+ * @throws Error when the compilation info contains at least one error message.
+ */
+async function compileOne(
+  device: GPUDevice,
+  code: string,
+  key: string,
+  layout: GPUBindGroupLayout,
+): Promise<GPUComputePipeline> {
+  const module = device.createShaderModule({ code });
+  const info = await module.getCompilationInfo();
+  let hasError = false;
+  const errLines: string[] = [];
+  for (const m of info.messages) {
+    const line = `[shader ${key}] ${m.type}: ${m.message} (line ${m.lineNum}, col ${m.linePos})`;
+    if (m.type === 'error') {
+      console.error(line);
+      log('err', line);
+      errLines.push(line);
+      hasError = true;
+    } else {
+      // Non-error messages go to the console only, not the page log.
+      console.warn(line);
+    }
+  }
+  // Only the first four error lines are included in the thrown message.
+  if (hasError) throw new Error(`WGSL compile failed for ${key}: ${errLines.slice(0, 4).join(' | ')}`);
+  return device.createComputePipelineAsync({
+    layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }),
+    compute: { module, entryPoint: 'main' },
+  });
+}
+
+/**
+ * Copies the first `byteLength` bytes of a GPU buffer to the host.
+ *
+ * @param device     GPU device owning `buf`.
+ * @param buf        Source buffer (must allow COPY_SRC).
+ * @param byteLength Number of bytes to read back.
+ * @returns A host-side copy of the data viewed as u32 words.
+ */
+async function readbackU32(device: GPUDevice, buf: GPUBuffer, byteLength: number): Promise<Uint32Array> {
+  const staging = device.createBuffer({ size: byteLength, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST });
+  try {
+    const enc = device.createCommandEncoder();
+    enc.copyBufferToBuffer(buf, 0, staging, 0, byteLength);
+    device.queue.submit([enc.finish()]);
+    await device.queue.onSubmittedWorkDone();
+    await staging.mapAsync(GPUMapMode.READ);
+    // slice(0) copies the bytes out before the mapping is released.
+    const out = new Uint32Array(staging.getMappedRange().slice(0));
+    staging.unmap();
+    return out;
+  } finally {
+    // Release the staging buffer even when the copy or the mapping fails.
+    staging.destroy();
+  }
+}
+
+/**
+ * Runs one converter measurement with the current module-level configuration:
+ * builds synthetic bases and CSR data, uploads them, compiles the active_sums
+ * and meta pipelines, does one warmup submit, optionally validates the GPU
+ * output against a host reference, then collects REPS timed samples for each
+ * of three workloads (active_sums only, meta only, both back to back).
+ *
+ * @param device GPU device to run on.
+ * @param sm     Shader manager that generates the two WGSL sources.
+ * @returns The configuration used, per-dispatch timing statistics in
+ *          microseconds, ns per slot for the combined workload, and whether
+ *          validation ran and passed.
+ */
+async function runOne(device: GPUDevice, sm: ShaderManager): Promise<BenchResult> {
+  const totalSlots = NUM_SUBTASKS * INPUT_SIZE;
+  const totalBuckets = NUM_SUBTASKS * NUM_COLUMNS;
+  log(
+    'info',
+    `=== T=${NUM_SUBTASKS} columns=${NUM_COLUMNS} input_size=${INPUT_SIZE} ` +
+      `total_slots=${totalSlots} total_buckets=${totalBuckets} WG=${WG} DISP=${DISP} REPS=${REPS}`,
+  );
+
+  // Synthetic packed bases: 32 bytes per element. Random bytes are fine
+  // — the converter is a layout-equivalent byte copy.
+  const rng = makeRng(SEED ^ 0xfeed);
+  const basesBytes = INPUT_SIZE * 32;
+  const basesX = new Uint8Array(basesBytes);
+  const basesY = new Uint8Array(basesBytes);
+  for (let i = 0; i < basesBytes; i += 4) {
+    // One draw per 4-byte word, stored little-endian; X and Y draws alternate.
+    const v = rng();
+    basesX[i] = v & 0xff;
+    basesX[i + 1] = (v >>> 8) & 0xff;
+    basesX[i + 2] = (v >>> 16) & 0xff;
+    basesX[i + 3] = (v >>> 24) & 0xff;
+    const w = rng();
+    basesY[i] = w & 0xff;
+    basesY[i + 1] = (w >>> 8) & 0xff;
+    basesY[i + 2] = (w >>> 16) & 0xff;
+    basesY[i + 3] = (w >>> 24) & 0xff;
+  }
+  log('info', `synthetic packed bases: ${(basesBytes * 2 / 1024).toFixed(1)} KiB total`);
+
+  const { rowPtr, valIdx } = buildSyntheticCsr(NUM_SUBTASKS, NUM_COLUMNS, INPUT_SIZE, SEED);
+  log('info', `synthetic CSR: rowPtr=${(rowPtr.byteLength / 1024).toFixed(1)} KiB valIdx=${(valIdx.byteLength / 1024).toFixed(1)} KiB`);
+
+  // Storage-buffer factory with optional copy-source / copy-destination usage.
+  const mk = (bytes: number, copySrc = false, copyDst = false): GPUBuffer => {
+    let usage = GPUBufferUsage.STORAGE;
+    if (copySrc) usage |= GPUBufferUsage.COPY_SRC;
+    if (copyDst) usage |= GPUBufferUsage.COPY_DST;
+    return device.createBuffer({ size: bytes, usage });
+  };
+
+  const M = totalSlots;
+  const activeBytes = 2 * 2 * M * 16;
+  const basesXBuf = mk(basesBytes, false, true);
+  const basesYBuf = mk(basesBytes, false, true);
+  const valIdxBuf = mk(valIdx.byteLength, false, true);
+  const rowPtrBuf = mk(rowPtr.byteLength, false, true);
+  const activeBuf = mk(activeBytes, true);
+  const countsBuf = mk(totalBuckets * 4, true);
+  const offsetsBuf = mk(totalBuckets * 4, true);
+  device.queue.writeBuffer(basesXBuf, 0, basesX);
+  device.queue.writeBuffer(basesYBuf, 0, basesY);
+  device.queue.writeBuffer(valIdxBuf, 0, valIdx);
+  device.queue.writeBuffer(rowPtrBuf, 0, rowPtr);
+
+  const paramsActiveBuf = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST });
+  device.queue.writeBuffer(paramsActiveBuf, 0, new Uint32Array([totalSlots, M, 0, 0]));
+  const paramsMetaBuf = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST });
+  device.queue.writeBuffer(paramsMetaBuf, 0, new Uint32Array([NUM_COLUMNS, totalBuckets, 0, 0]));
+
+  const activeLayout = device.createBindGroupLayout({
+    entries: [
+      { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+      { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+      { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+      { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+      { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } },
+    ],
+  });
+  const metaLayout = device.createBindGroupLayout({
+    entries: [
+      { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+      { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+      { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+      { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } },
+    ],
+  });
+  const activePipeline = await compileOne(device, sm.gen_csr_to_v2_active_sums_shader(WG), `csr-to-v2-active_sums-wg${WG}`, activeLayout);
+  const metaPipeline = await compileOne(device, sm.gen_csr_to_v2_meta_shader(WG), `csr-to-v2-meta-wg${WG}`, metaLayout);
+
+  const activeBind = device.createBindGroup({
+    layout: activeLayout,
+    entries: [
+      { binding: 0, resource: { buffer: valIdxBuf } },
+      { binding: 1, resource: { buffer: basesXBuf } },
+      { binding: 2, resource: { buffer: basesYBuf } },
+      { binding: 3, resource: { buffer: activeBuf } },
+      { binding: 4, resource: { buffer: paramsActiveBuf } },
+    ],
+  });
+  const metaBind = device.createBindGroup({
+    layout: metaLayout,
+    entries: [
+      { binding: 0, resource: { buffer: rowPtrBuf } },
+      { binding: 1, resource: { buffer: countsBuf } },
+      { binding: 2, resource: { buffer: offsetsBuf } },
+      { binding: 3, resource: { buffer: paramsMetaBuf } },
+    ],
+  });
+
+  const groupsActive = Math.ceil(totalSlots / WG);
+  const groupsMeta = Math.ceil(totalBuckets / WG);
+  log('info', `dispatch shape: active=${groupsActive} wgs, meta=${groupsMeta} wgs`);
+
+  // Warmup
+  {
+    const enc = device.createCommandEncoder();
+    let pass = enc.beginComputePass();
+    pass.setPipeline(activePipeline);
+    pass.setBindGroup(0, activeBind);
+    pass.dispatchWorkgroups(groupsActive, 1, 1);
+    pass.end();
+    pass = enc.beginComputePass();
+    pass.setPipeline(metaPipeline);
+    pass.setBindGroup(0, metaBind);
+    pass.dispatchWorkgroups(groupsMeta, 1, 1);
+    pass.end();
+    device.queue.submit([enc.finish()]);
+    await device.queue.onSubmittedWorkDone();
+  }
+  log('info', 'warmup done');
+
+  // Validation reads back what the warmup submit produced.
+  let validated = false;
+  if (VALIDATE) {
+    const gpuActive = await readbackU32(device, activeBuf, activeBytes);
+    const gpuCounts = await readbackU32(device, countsBuf, totalBuckets * 4);
+    const gpuOffsets = await readbackU32(device, offsetsBuf, totalBuckets * 4);
+    const refBasesX = new Uint32Array(basesX.buffer, basesX.byteOffset, INPUT_SIZE * 8);
+    const refBasesY = new Uint32Array(basesY.buffer, basesY.byteOffset, INPUT_SIZE * 8);
+
+    const mismatches: string[] = [];
+    let xFails = 0;
+    let yFails = 0;
+    // NOTE(review): activeBuf holds activeBytes / 4 = 16 * M u32 words and the
+    // X plane is compared at slot * 8 + w, i.e. words [0, 8 * M). An offset of
+    // 4 * M therefore falls inside that X range, whereas 8 * M would tile the
+    // buffer exactly. Confirm against the shader's Y-plane base before
+    // trusting validate=1 results.
+    const planeYU32Off = 2 * 2 * M;
+    for (let slot = 0; slot < totalSlots; slot++) {
+      const ptIdx = valIdx[slot];
+      for (let w = 0; w < 8; w++) {
+        const got = gpuActive[slot * 8 + w];
+        const want = refBasesX[ptIdx * 8 + w];
+        if (got !== want) {
+          xFails++;
+          if (mismatches.length < 8) mismatches.push(`activeX[slot=${slot} w=${w}]: gpu=${got} ref=${want}`);
+        }
+        const gotY = gpuActive[planeYU32Off + slot * 8 + w];
+        const wantY = refBasesY[ptIdx * 8 + w];
+        if (gotY !== wantY) {
+          yFails++;
+          if (mismatches.length < 8) mismatches.push(`activeY[slot=${slot} w=${w}]: gpu=${gotY} ref=${wantY}`);
+        }
+      }
+    }
+    let mFails = 0;
+    for (let s = 0; s < NUM_SUBTASKS; s++) {
+      const rpBase = s * (NUM_COLUMNS + 1);
+      const ccBase = s * NUM_COLUMNS;
+      for (let b = 0; b < NUM_COLUMNS; b++) {
+        const wantCount = rowPtr[rpBase + b + 1] - rowPtr[rpBase + b];
+        const wantOffset = rowPtr[rpBase + b];
+        if (gpuCounts[ccBase + b] !== wantCount) {
+          mFails++;
+          if (mismatches.length < 12) mismatches.push(`counts[s=${s} b=${b}]: gpu=${gpuCounts[ccBase + b]} ref=${wantCount}`);
+        }
+        if (gpuOffsets[ccBase + b] !== wantOffset) {
+          mFails++;
+          if (mismatches.length < 12) mismatches.push(`offsets[s=${s} b=${b}]: gpu=${gpuOffsets[ccBase + b]} ref=${wantOffset}`);
+        }
+      }
+    }
+    if (xFails === 0 && yFails === 0 && mFails === 0) {
+      validated = true;
+      log('ok', `validation: PASS — converter output byte-equivalent to host reference (${totalSlots} slots, ${totalBuckets} buckets)`);
+    } else {
+      log('err', `validation: FAIL — activeX mismatches=${xFails}, activeY=${yFails}, meta=${mFails}`);
+      for (const m of mismatches.slice(0, 12)) log('err', `  ${m}`);
+    }
+  }
+
+  // Timed: DISP back-to-back encoded as one submit per rep, split active vs meta.
+  const activeSamples: number[] = [];
+  const metaSamples: number[] = [];
+  const combinedSamples: number[] = [];
+  for (let r = 0; r < REPS; r++) {
+    {
+      const enc = device.createCommandEncoder();
+      for (let d = 0; d < DISP; d++) {
+        const pass = enc.beginComputePass();
+        pass.setPipeline(activePipeline);
+        pass.setBindGroup(0, activeBind);
+        pass.dispatchWorkgroups(groupsActive, 1, 1);
+        pass.end();
+      }
+      // The clock starts after encoding, so only submit + completion is timed.
+      const t0 = performance.now();
+      device.queue.submit([enc.finish()]);
+      await device.queue.onSubmittedWorkDone();
+      activeSamples.push(performance.now() - t0);
+    }
+    {
+      const enc = device.createCommandEncoder();
+      for (let d = 0; d < DISP; d++) {
+        const pass = enc.beginComputePass();
+        pass.setPipeline(metaPipeline);
+        pass.setBindGroup(0, metaBind);
+        pass.dispatchWorkgroups(groupsMeta, 1, 1);
+        pass.end();
+      }
+      const t0 = performance.now();
+      device.queue.submit([enc.finish()]);
+      await device.queue.onSubmittedWorkDone();
+      metaSamples.push(performance.now() - t0);
+    }
+    {
+      const enc = device.createCommandEncoder();
+      for (let d = 0; d < DISP; d++) {
+        let pass = enc.beginComputePass();
+        pass.setPipeline(activePipeline);
+        pass.setBindGroup(0, activeBind);
+        pass.dispatchWorkgroups(groupsActive, 1, 1);
+        pass.end();
+        pass = enc.beginComputePass();
+        pass.setPipeline(metaPipeline);
+        pass.setBindGroup(0, metaBind);
+        pass.dispatchWorkgroups(groupsMeta, 1, 1);
+        pass.end();
+      }
+      const t0 = performance.now();
+      device.queue.submit([enc.finish()]);
+      await device.queue.onSubmittedWorkDone();
+      combinedSamples.push(performance.now() - t0);
+    }
+  }
+  // Converts per-submit milliseconds into per-dispatch microseconds.
+  const stat = (xs: number[]): { min: number; median: number; max: number } => {
+    const med = median(xs);
+    return { min: (Math.min(...xs) / DISP) * 1000, median: (med / DISP) * 1000, max: (Math.max(...xs) / DISP) * 1000 };
+  };
+  const activeStats = stat(activeSamples);
+  const metaStats = stat(metaSamples);
+  const combinedStats = stat(combinedSamples);
+  const nsPerSlotCombined = (combinedStats.median * 1000) / totalSlots;
+  log(
+    'ok',
+    `active_sums per-dispatch: median=${activeStats.median.toFixed(2)}μs min=${activeStats.min.toFixed(2)}μs max=${activeStats.max.toFixed(2)}μs`,
+  );
+  log(
+    'ok',
+    `meta per-dispatch: median=${metaStats.median.toFixed(2)}μs min=${metaStats.min.toFixed(2)}μs max=${metaStats.max.toFixed(2)}μs`,
+  );
+  log(
+    'ok',
+    `combined per-dispatch: median=${combinedStats.median.toFixed(2)}μs min=${combinedStats.min.toFixed(2)}μs max=${combinedStats.max.toFixed(2)}μs ` +
+      `→ ${nsPerSlotCombined.toFixed(3)} ns/slot (${totalSlots} slots)`,
+  );
+
+  basesXBuf.destroy();
+  basesYBuf.destroy();
+  valIdxBuf.destroy();
+  rowPtrBuf.destroy();
+  activeBuf.destroy();
+  countsBuf.destroy();
+  offsetsBuf.destroy();
+  paramsActiveBuf.destroy();
+  paramsMetaBuf.destroy();
+
+  return {
+    num_subtasks: NUM_SUBTASKS,
+    num_columns: NUM_COLUMNS,
+    input_size: INPUT_SIZE,
+    total_slots: totalSlots,
+    total_buckets: totalBuckets,
+    wg: WG,
+    disp: DISP,
+    reps: REPS,
+    active_sums_us: activeStats,
+    meta_us: metaStats,
+    combined_us: combinedStats,
+    ns_per_slot_combined: nsPerSlotCombined,
+    validated,
+  };
+}
+
+/**
+ * Applies query-string overrides to the module-level bench configuration and
+ * returns the effective values.
+ *
+ * Recognised keys: subtasks, columns, input, wg, disp, reps, seed (all
+ * base-10 integers) and validate (enabled only by the exact value "1").
+ * Absent or empty keys leave the default untouched.
+ *
+ * @returns The effective parameters after overrides.
+ * @throws Error when a size/count parameter is not a positive integer, or when
+ *         seed is not an integer. Previously such values flowed through as
+ *         NaN/0 and produced NaN timings or invalid dispatch sizes.
+ */
+function parseParams() {
+  const qp = new URLSearchParams(window.location.search);
+  // Parses an integer override; returns `current` when the key is absent or empty.
+  const intParam = (key: string, current: number, requirePositive: boolean): number => {
+    const raw = qp.get(key);
+    if (!raw) return current;
+    const parsed = parseInt(raw, 10);
+    if (!Number.isInteger(parsed) || (requirePositive && parsed <= 0)) {
+      const expected = requirePositive ? 'a positive integer' : 'an integer';
+      throw new Error(`invalid query param ${key}=${raw}: expected ${expected}`);
+    }
+    return parsed;
+  };
+  NUM_SUBTASKS = intParam('subtasks', NUM_SUBTASKS, true);
+  NUM_COLUMNS = intParam('columns', NUM_COLUMNS, true);
+  INPUT_SIZE = intParam('input', INPUT_SIZE, true);
+  WG = intParam('wg', WG, true);
+  DISP = intParam('disp', DISP, true);
+  REPS = intParam('reps', REPS, true);
+  SEED = intParam('seed', SEED, false);
+  if (qp.get('validate') === '1') VALIDATE = true;
+  return {
+    subtasks: NUM_SUBTASKS,
+    columns: NUM_COLUMNS,
+    input_size: INPUT_SIZE,
+    wg: WG,
+    disp: DISP,
+    reps: REPS,
+    seed: SEED,
+    validate: VALIDATE,
+  };
+}
+
+/**
+ * Page entry point: parses parameters, acquires the device, runs a single
+ * measurement and publishes it. Any failure is logged and recorded on
+ * benchState instead of being rethrown.
+ */
+async function main() {
+  try {
+    const hasWebGpu = 'gpu' in navigator;
+    if (!hasWebGpu) {
+      throw new Error('navigator.gpu missing');
+    }
+
+    const params = parseParams();
+    benchState.params = params;
+    log('info', `params: ${JSON.stringify(params)}`);
+    benchState.state = 'running';
+
+    const device = await get_device();
+    log('info', 'WebGPU device acquired');
+
+    const shaderManager = new ShaderManager(NUM_SUBTASKS, NUM_COLUMNS, BN254_CURVE_CONFIG, false);
+    const result = await runOne(device, shaderManager);
+    benchState.results.push(result);
+
+    const { active_sums_us, meta_us, combined_us, ns_per_slot_combined, validated } = result;
+    resultsClient.postProgress({
+      kind: 'csr_to_v2_done',
+      active_sums_us,
+      meta_us,
+      combined_us,
+      ns_per_slot_combined,
+      validated,
+    });
+
+    benchState.state = 'done';
+    log('ok', 'done');
+  } catch (e) {
+    let msg: string;
+    if (e instanceof Error) {
+      msg = `${e.message}\n${e.stack}`;
+    } else {
+      msg = String(e);
+    }
+    log('err', `FATAL: ${msg}`);
+    benchState.state = 'error';
+    benchState.error = msg;
+  }
+}
+
+// Start the bench; whatever happens, attempt the final results post afterwards
+// and ignore a failure of that post.
+void (async () => {
+  try {
+    await main();
+  } catch (e) {
+    const msg = e instanceof Error ? e.message : String(e);
+    log('err', `unhandled: ${msg}`);
+    benchState.state = 'error';
+    benchState.error = msg;
+  } finally {
+    postFinal().catch(() => {});
+  }
+})();
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.html b/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.html
new file mode 100644
index 000000000000..9605a79b5605
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.html
@@ -0,0 +1,37 @@
+
+
+
+
+ Workgroup-scan fused batch-affine round bench (WebGPU)
+
+
+
+ Workgroup-scan fused batch-affine round bench (WebGPU)
+ Query params: ?reps=R&total=N&sizes=A,B,C&skip_correctness=1
+
+
+
+
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts b/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts
new file mode 100644
index 000000000000..264bbc0a6f77
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts
@@ -0,0 +1,549 @@
+///
+// Standalone WebGPU bench + correctness oracle for the new
+// `batch_affine_fused_wg_scan` kernel (workgroup-scan fused batch-affine
+// round, validated to mirror bench_batch_affine's 22 ns/pair design with
+// bucket-indirect loads via pair_target_meta).
+//
+// Inputs: TOTAL_PAIRS on-curve BN254 G1 affine pairs (P_i, Q_i),
+// distributed across ONE synthetic subtask. Each pair maps to a unique
+// bucket (`pair_target_meta[2i] = i`, `pair_target_meta[2i+1] = i`,
+// `val_idx[i] = i`), so the scheduler's "distinct-buckets within a
+// subtask round" invariant holds trivially. The kernel writes the
+// affine sum `R_i = P_i + Q_i` to running_x[i] / running_y[i]; we
+// decode packed Mont form back to canonical and compare to noble's
+// reference P.add(Q).
+//
+// SAFETY: NO MSM pipeline touched. The shader uses only compile-time-
+// const loop bounds (BS, TPB, NUM_WORDS). One dispatch per measurement.
+
+import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
+import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
+import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
+import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js';
+import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js';
+import { makeResultsClient } from './results_post.js';
+import { bn254 } from '@noble/curves/bn254';
+
+const G1 = bn254.G1.ProjectivePoint; // noble's BN254 G1 point class; buildPairs() uses G1.BASE as the generator
+const FR_ORDER = bn254.fields.Fr.ORDER; // exclusive upper bound for the random scalars drawn in buildPairs()
+
+const DEFAULT_TOTAL_PAIRS = 1 << 16;
+let TOTAL_PAIRS = DEFAULT_TOTAL_PAIRS; // pairs per run; runOne() requires it to be a multiple of each batch size
+
+const DEFAULT_BATCH_SIZES = [256, 512, 1024, 2048] as const;
+let BATCH_SIZES: readonly number[] = DEFAULT_BATCH_SIZES; // batch sizes to measure — presumably overridden by ?sizes= (parsing is outside this view)
+
+let SKIP_CORRECTNESS = false; // when true, runOne() skips the readback and comparison against the reference sums
+
+/**
+ * Splits a batch size into threads-per-workgroup (`tpb`) and per-thread batch
+ * (`bs`) such that tpb * bs === batchSize, capping tpb at 64.
+ *
+ * @param batchSize Pairs handled by one workgroup.
+ * @returns `{ tpb, bs }`, both positive integers.
+ * @throws RangeError when batchSize is not a positive integer, or exceeds 64
+ *         without being a multiple of 64 (which would give a fractional bs).
+ */
+function tpbBsFor(batchSize: number): { tpb: number; bs: number } {
+  if (!Number.isInteger(batchSize) || batchSize <= 0) {
+    throw new RangeError(`batch size must be a positive integer, got ${batchSize}`);
+  }
+  if (batchSize <= 64) return { tpb: batchSize, bs: 1 };
+  if (batchSize % 64 !== 0) {
+    throw new RangeError(`batch size above 64 must be a multiple of 64, got ${batchSize}`);
+  }
+  return { tpb: 64, bs: batchSize / 64 };
+}
+
+/**
+ * Creates a seeded 32-bit linear congruential generator. The seed is taken as
+ * an unsigned 32-bit value, with zero mapped to 1; every call advances the
+ * state and returns it.
+ */
+function makeRng(seed: number): () => number {
+  const unsignedSeed = seed >>> 0;
+  let lcgState = unsignedSeed !== 0 ? unsignedSeed : 1;
+  const step = (): number => {
+    const scaled = Math.imul(lcgState, 1664525);
+    lcgState = (scaled + 1013904223) >>> 0;
+    return lcgState;
+  };
+  return step;
+}
+
+/**
+ * Draws a value v with 0 < v < p by rejection sampling: assembles
+ * ceil(bitlen(p) / 8) bytes from `rng`, masks to bitlen(p) bits, and retries
+ * until the value is in range.
+ *
+ * Each byte is the TOP byte of a 32-bit draw. The low byte of a power-of-two
+ * modulus LCG (such as makeRng) repeats with period 256, so taking
+ * `rng() & 0xff` made a 32-byte value cycle through only 8 distinct results.
+ *
+ * @param p   Exclusive upper bound; must be greater than 1.
+ * @param rng Source of 32-bit unsigned draws.
+ * @returns A value in the open interval (0, p).
+ * @throws RangeError when p <= 1, since no value could ever be accepted.
+ */
+function randomBelow(p: bigint, rng: () => number): bigint {
+  if (p <= 1n) {
+    throw new RangeError(`randomBelow: bound must be > 1, got ${p}`);
+  }
+  const bitlen = p.toString(2).length;
+  const byteLen = Math.ceil(bitlen / 8);
+  const mask = (1n << BigInt(bitlen)) - 1n;
+  for (;;) {
+    let v = 0n;
+    for (let i = 0; i < byteLen; i++) v = (v << 8n) | BigInt((rng() >>> 24) & 0xff);
+    v &= mask;
+    if (v > 0n && v < p) return v;
+  }
+}
+
+/**
+ * Middle element of the ascending-sorted samples (upper-middle for even
+ * counts); NaN when there are no samples. Sorts a copy, not the input.
+ */
+function median(xs: number[]): number {
+  if (xs.length === 0) {
+    return NaN;
+  }
+  const ordered = Array.from(xs);
+  ordered.sort((lhs, rhs) => lhs - rhs);
+  const mid = Math.floor(ordered.length / 2);
+  return ordered[mid];
+}
+
+/**
+ * Splits the low 256 bits of `v` into eight 32-bit limbs, least significant
+ * limb first.
+ */
+function biToLe32u32(v: bigint): Uint32Array {
+  const limbs = new Uint32Array(8);
+  for (let limb = 0; limb < limbs.length; limb++) {
+    const shifted = v >> BigInt(32 * limb);
+    limbs[limb] = Number(BigInt.asUintN(32, shifted));
+  }
+  return limbs;
+}
+
+/**
+ * Reassembles eight consecutive little-endian 32-bit limbs, starting at index
+ * `off`, into a single bigint.
+ */
+function le32u32ToBi(u32: Uint32Array, off: number): bigint {
+  let acc = 0n;
+  for (let limb = 0; limb < 8; limb++) {
+    const word = BigInt(u32[off + limb] >>> 0);
+    acc |= word << BigInt(32 * limb);
+  }
+  return acc;
+}
+
+interface PerSizeResult { // result record for one batch size (runOne's return value; its field assignments are outside this view)
+ batch_size: number;
+ tpb: number; // presumably the tpb from tpbBsFor(batch_size) — confirm in runOne's return
+ bs: number; // presumably the bs from tpbBsFor(batch_size) — confirm in runOne's return
+ num_wgs: number;
+ total_pairs: number;
+ median_ms: number;
+ min_ms: number;
+ max_ms: number;
+ ns_per_pair: number;
+ samples_ms: number[];
+ correctness: 'pass' | 'fail' | 'skipped'; // 'skipped' is the value kept when SKIP_CORRECTNESS is set
+ correctness_first_fail?: { i: number; expected_x: string; got_x: string; expected_y: string; got_y: string }; // first mismatching pair, coordinates as decimal strings
+}
+
+interface BenchState { // mutable run status; exposed on window.__bench and sent by postFinal()
+ state: 'boot' | 'running' | 'done' | 'error'; // lifecycle stage; starts at 'boot'
+ params: { reps: number; total: number; sizes: readonly number[]; skip_correctness: boolean } | null; // null until set
+ results: PerSizeResult[];
+ error: string | null;
+ log: string[]; // every log() line, prefixed with its level
+}
+
+const benchState: BenchState = { // single mutable status record for this page
+ state: 'boot',
+ params: null,
+ results: [],
+ error: null,
+ log: [],
+};
+(window as unknown as { __bench: BenchState }).__bench = benchState; // expose the live state on window.__bench
+
+const resultsClient = makeResultsClient({ page: 'bench-fused-wg-scan' }); // client used by postFinal() to send results
+(window as unknown as { __runId: string }).__runId = resultsClient.runId; // expose the client's run id on window.__runId
+
+/**
+ * Posts the final benchState snapshot, together with the browser's user agent
+ * and hardwareConcurrency, through the results client.
+ *
+ * @returns A promise that settles when postResults() settles.
+ */
+async function postFinal(): Promise<void> {
+  await resultsClient.postResults({
+    state: benchState.state,
+    params: benchState.params,
+    results: benchState.results,
+    error: benchState.error,
+    log: benchState.log,
+    userAgent: navigator.userAgent,
+    hardwareConcurrency: navigator.hardwareConcurrency,
+  });
+}
+
+const $log = document.getElementById('log') as HTMLDivElement;
+/**
+ * Writes one message to the page's log element, to benchState.log (with a
+ * level prefix) and to the console.
+ *
+ * @param level Severity; every level except 'info' is also used as the row's CSS class.
+ * @param msg   Message text.
+ */
+function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) {
+  const styled = level === 'ok' || level === 'err' || level === 'warn';
+  const entry = document.createElement('div');
+  entry.className = styled ? level : '';
+  entry.textContent = msg;
+  $log.appendChild(entry);
+  benchState.log.push(`[${level}] ${msg}`);
+  console.log(`[bench-fused-wg-scan] ${msg}`);
+}
+
+/**
+ * Compiles WGSL source, reports every compilation message, and builds the
+ * compute pipeline (entry point `main`) together with its nine-binding layout.
+ *
+ * @param device   GPU device used for all object creation.
+ * @param code     WGSL source.
+ * @param cacheKey Label used in diagnostics.
+ * @returns The pipeline and the bind group layout it was created with.
+ * @throws Error listing all error messages when compilation reports any.
+ */
+async function createPipeline(
+  device: GPUDevice,
+  code: string,
+  cacheKey: string,
+): Promise<{ pipeline: GPUComputePipeline; layout: GPUBindGroupLayout }> {
+  const module = device.createShaderModule({ code });
+  const { messages } = await module.getCompilationInfo();
+  const errLines: string[] = [];
+  for (const { type, message, lineNum, linePos } of messages) {
+    const line = `[shader ${cacheKey}] ${type}: ${message} (line ${lineNum}, col ${linePos})`;
+    const isError = type === 'error';
+    if (isError) {
+      console.error(line);
+    } else {
+      console.warn(line);
+    }
+    log(isError ? 'err' : 'warn', line);
+    if (isError) {
+      errLines.push(line);
+    }
+  }
+  if (errLines.length > 0) {
+    throw new Error(`WGSL compile failed for ${cacheKey}: ${errLines.join(' | ')}`);
+  }
+
+  // Buffer binding type for bindings 0..8, in order.
+  const bindingTypes = [
+    'read-only-storage',
+    'read-only-storage',
+    'read-only-storage',
+    'storage',
+    'storage',
+    'read-only-storage',
+    'storage',
+    'storage',
+    'uniform',
+  ] as const;
+  const layout = device.createBindGroupLayout({
+    entries: bindingTypes.map((type, binding) => ({
+      binding,
+      visibility: GPUShaderStage.COMPUTE,
+      buffer: { type },
+    })),
+  });
+  const pipelineLayout = device.createPipelineLayout({ bindGroupLayouts: [layout] });
+  const pipeline = await device.createComputePipelineAsync({
+    layout: pipelineLayout,
+    compute: { module, entryPoint: 'main' },
+  });
+  return { pipeline, layout };
+}
+
+interface PointPair { // one test case: two affine input points and their reference sum, as built by buildPairs()
+ p: { x: bigint; y: bigint }; // first input point (affine)
+ q: { x: bigint; y: bigint }; // second input point (affine); buildPairs() guarantees q.x differs from p.x
+ r: { x: bigint; y: bigint }; // affine form of p + q computed on the host
+}
+
+/**
+ * Builds `n` deterministic test cases. Each case takes two random scalar
+ * multiples of the G1 generator and their sum; a candidate is discarded and
+ * redrawn when either point or the sum is the identity, or when the two
+ * points share an x coordinate.
+ *
+ * @param n    Number of pairs to produce.
+ * @param seed Seed for the deterministic generator.
+ * @returns The generated pairs, each with its reference sum.
+ */
+function buildPairs(n: number, seed: number): PointPair[] {
+  const rng = makeRng(seed);
+
+  // Draws one candidate; returns undefined when it must be rejected.
+  const drawCandidate = (): PointPair | undefined => {
+    const scalarP = randomBelow(FR_ORDER, rng);
+    const scalarQ = randomBelow(FR_ORDER, rng);
+    const pointP = G1.BASE.multiply(scalarP);
+    const pointQ = G1.BASE.multiply(scalarQ);
+    if (pointP.is0() || pointQ.is0()) return undefined;
+    const affineP = pointP.toAffine();
+    const affineQ = pointQ.toAffine();
+    if (affineP.x === affineQ.x) return undefined;
+    const sum = pointP.add(pointQ);
+    if (sum.is0()) return undefined;
+    return { p: affineP, q: affineQ, r: sum.toAffine() };
+  };
+
+  const pairs: PointPair[] = [];
+  while (pairs.length < n) {
+    const candidate = drawCandidate();
+    if (candidate !== undefined) {
+      pairs.push(candidate);
+    }
+  }
+  return pairs;
+}
+
+async function runOne(
+ device: GPUDevice,
+ sm: ShaderManager,
+ batchSize: number,
+ reps: number,
+ R: bigint,
+ p: bigint,
+ pairs: PointPair[],
+): Promise {
+ const { tpb, bs } = tpbBsFor(batchSize);
+ if (TOTAL_PAIRS % batchSize !== 0) {
+ throw new Error(`TOTAL_PAIRS=${TOTAL_PAIRS} must be a multiple of batch_size=${batchSize}`);
+ }
+ const numWgs = TOTAL_PAIRS / batchSize;
+ log('info', `=== batch_size=${batchSize}: TPB=${tpb} BS=${bs} num_WGs=${numWgs}`);
+
+ const code = sm.gen_batch_affine_fused_wg_scan_shader(tpb, bs);
+ const cacheKey = `bench-fused-wg-scan-T${tpb}-S${bs}`;
+ log('info', `compiling shader (${code.length} chars)`);
+ (window as unknown as Record)[`__shader_${batchSize}`] = code;
+ const { pipeline, layout } = await createPipeline(device, code, cacheKey);
+
+ const fieldBytes = 32;
+ const fieldsPerPair = 2;
+ const fieldsTotalIo = TOTAL_PAIRS * fieldsPerPair;
+
+ const newPointXAB = new ArrayBuffer(TOTAL_PAIRS * fieldBytes);
+ const newPointYAB = new ArrayBuffer(TOTAL_PAIRS * fieldBytes);
+ const runningXAB = new ArrayBuffer(TOTAL_PAIRS * fieldBytes);
+ const runningYAB = new ArrayBuffer(TOTAL_PAIRS * fieldBytes);
+
+ const nx32 = new Uint32Array(newPointXAB);
+ const ny32 = new Uint32Array(newPointYAB);
+ const rx32 = new Uint32Array(runningXAB);
+ const ry32 = new Uint32Array(runningYAB);
+
+ for (let i = 0; i < TOTAL_PAIRS; i++) {
+ const { p: pp, q: qq } = pairs[i];
+ const pxM = (pp.x * R) % p;
+ const pyM = (pp.y * R) % p;
+ const qxM = (qq.x * R) % p;
+ const qyM = (qq.y * R) % p;
+ rx32.set(biToLe32u32(pxM), i * 8);
+ ry32.set(biToLe32u32(pyM), i * 8);
+ nx32.set(biToLe32u32(qxM), i * 8);
+ ny32.set(biToLe32u32(qyM), i * 8);
+ }
+
+ const valIdxAB = new Uint32Array(TOTAL_PAIRS);
+ const ptmAB = new Uint32Array(TOTAL_PAIRS * 2);
+ for (let i = 0; i < TOTAL_PAIRS; i++) {
+ valIdxAB[i] = i;
+ ptmAB[2 * i] = i;
+ ptmAB[2 * i + 1] = i;
+ }
+ const countAB = new Uint32Array([TOTAL_PAIRS]);
+ const paramsAB = new Uint32Array([TOTAL_PAIRS, TOTAL_PAIRS, 0, 0]);
+
+ const mkSb = (size: number, copyDst = true, copySrc = false): GPUBuffer => {
+ let usage = GPUBufferUsage.STORAGE;
+ if (copyDst) usage |= GPUBufferUsage.COPY_DST;
+ if (copySrc) usage |= GPUBufferUsage.COPY_SRC;
+ return device.createBuffer({ size, usage });
+ };
+
+ const valIdxBuf = mkSb(valIdxAB.byteLength);
+ const newPointXBuf = mkSb(newPointXAB.byteLength);
+ const newPointYBuf = mkSb(newPointYAB.byteLength);
+ const runningXBuf = mkSb(runningXAB.byteLength, true, true);
+ const runningYBuf = mkSb(runningYAB.byteLength, true, true);
+ const ptmBuf = mkSb(ptmAB.byteLength);
+ const prefixBuf = mkSb(TOTAL_PAIRS * 20 * 4, false);
+ const countBuf = mkSb(countAB.byteLength);
+ const paramsBuf = device.createBuffer({
+ size: 16,
+ usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
+ });
+
+ device.queue.writeBuffer(valIdxBuf, 0, valIdxAB);
+ device.queue.writeBuffer(newPointXBuf, 0, newPointXAB);
+ device.queue.writeBuffer(newPointYBuf, 0, newPointYAB);
+ device.queue.writeBuffer(ptmBuf, 0, ptmAB);
+ device.queue.writeBuffer(countBuf, 0, countAB);
+ device.queue.writeBuffer(paramsBuf, 0, paramsAB);
+
+ const bindGroup = device.createBindGroup({
+ layout,
+ entries: [
+ { binding: 0, resource: { buffer: valIdxBuf } },
+ { binding: 1, resource: { buffer: newPointXBuf } },
+ { binding: 2, resource: { buffer: newPointYBuf } },
+ { binding: 3, resource: { buffer: runningXBuf } },
+ { binding: 4, resource: { buffer: runningYBuf } },
+ { binding: 5, resource: { buffer: ptmBuf } },
+ { binding: 6, resource: { buffer: prefixBuf } },
+ { binding: 7, resource: { buffer: countBuf } },
+ { binding: 8, resource: { buffer: paramsBuf } },
+ ],
+ });
+
+ device.queue.writeBuffer(runningXBuf, 0, runningXAB);
+ device.queue.writeBuffer(runningYBuf, 0, runningYAB);
+
+ const dispatch = async (resetState: boolean) => {
+ if (resetState) {
+ device.queue.writeBuffer(runningXBuf, 0, runningXAB);
+ device.queue.writeBuffer(runningYBuf, 0, runningYAB);
+ await device.queue.onSubmittedWorkDone();
+ }
+ const encoder = device.createCommandEncoder();
+ const pass = encoder.beginComputePass();
+ pass.setPipeline(pipeline);
+ pass.setBindGroup(0, bindGroup);
+ pass.dispatchWorkgroups(numWgs, 1, 1);
+ pass.end();
+ const t0 = performance.now();
+ device.queue.submit([encoder.finish()]);
+ await device.queue.onSubmittedWorkDone();
+ return performance.now() - t0;
+ };
+
+ await dispatch(true);
+ log('info', 'warmup dispatch returned');
+
+ let correctness: 'pass' | 'fail' | 'skipped' = 'skipped';
+ let correctness_first_fail: PerSizeResult['correctness_first_fail'];
+
+ if (!SKIP_CORRECTNESS) {
+ const stagingX = device.createBuffer({
+ size: TOTAL_PAIRS * fieldBytes,
+ usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
+ });
+ const stagingY = device.createBuffer({
+ size: TOTAL_PAIRS * fieldBytes,
+ usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
+ });
+ const enc = device.createCommandEncoder();
+ enc.copyBufferToBuffer(runningXBuf, 0, stagingX, 0, TOTAL_PAIRS * fieldBytes);
+ enc.copyBufferToBuffer(runningYBuf, 0, stagingY, 0, TOTAL_PAIRS * fieldBytes);
+ device.queue.submit([enc.finish()]);
+ await Promise.all([stagingX.mapAsync(GPUMapMode.READ), stagingY.mapAsync(GPUMapMode.READ)]);
+ const gpuX = new Uint32Array(stagingX.getMappedRange().slice(0));
+ const gpuY = new Uint32Array(stagingY.getMappedRange().slice(0));
+ stagingX.unmap();
+ stagingY.unmap();
+ stagingX.destroy();
+ stagingY.destroy();
+
+ const rInv = (() => {
+ let g = R % p;
+ let r = p;
+ let x = 0n;
+ let y = 1n;
+ while (g !== 0n) {
+ const q = r / g;
+ [r, g] = [g, r - q * g];
+ [x, y] = [y, x - q * y];
+ }
+ return ((x % p) + p) % p;
+ })();
+
+ let mismatches = 0;
+ const MAX_REPORTED = 4;
+ for (let i = 0; i < TOTAL_PAIRS; i++) {
+ const gxM = le32u32ToBi(gpuX, i * 8);
+ const gyM = le32u32ToBi(gpuY, i * 8);
+ const gx = (gxM * rInv) % p;
+ const gy = (gyM * rInv) % p;
+ const ex = pairs[i].r.x;
+ const ey = pairs[i].r.y;
+ if (gx !== ex || gy !== ey) {
+ mismatches++;
+ if (!correctness_first_fail) {
+ correctness_first_fail = {
+ i,
+ expected_x: ex.toString(),
+ got_x: gx.toString(),
+ expected_y: ey.toString(),
+ got_y: gy.toString(),
+ };
+ }
+ if (mismatches > MAX_REPORTED) break;
+ }
+ }
+ correctness = mismatches === 0 ? 'pass' : 'fail';
+ if (mismatches === 0) {
+ log('ok', `correctness: pass (${TOTAL_PAIRS}/${TOTAL_PAIRS} pairs match noble reference)`);
+ } else {
+ log('err', `correctness: FAIL (${mismatches}+ mismatches; first @ i=${correctness_first_fail!.i})`);
+ log('err', ` expected R.x = ${correctness_first_fail!.expected_x.slice(0, 24)}...`);
+ log('err', ` got R.x = ${correctness_first_fail!.got_x.slice(0, 24)}...`);
+ }
+ }
+
+ const samples: number[] = [];
+ for (let r = 0; r < reps; r++) {
+ samples.push(await dispatch(false));
+ }
+ const med = median(samples);
+ const mn = Math.min(...samples);
+ const mx = Math.max(...samples);
+ const nsPerPair = (med * 1e6) / TOTAL_PAIRS;
+
+ log(
+ correctness === 'fail' ? 'err' : 'ok',
+ `B=${batchSize}: median=${med.toFixed(3)}ms min=${mn.toFixed(3)}ms max=${mx.toFixed(3)}ms ns/pair=${nsPerPair.toFixed(1)} correctness=${correctness}`,
+ );
+
+ valIdxBuf.destroy();
+ newPointXBuf.destroy();
+ newPointYBuf.destroy();
+ runningXBuf.destroy();
+ runningYBuf.destroy();
+ ptmBuf.destroy();
+ prefixBuf.destroy();
+ countBuf.destroy();
+ paramsBuf.destroy();
+
+ return {
+ batch_size: batchSize,
+ tpb,
+ bs,
+ num_wgs: numWgs,
+ total_pairs: TOTAL_PAIRS,
+ median_ms: med,
+ min_ms: mn,
+ max_ms: mx,
+ ns_per_pair: nsPerPair,
+ samples_ms: samples,
+ correctness,
+ correctness_first_fail,
+ };
+}
+
+/**
+ * Reads the bench options from the page's query string, validates them and
+ * stores the overrides in the module-level settings (TOTAL_PAIRS,
+ * BATCH_SIZES, SKIP_CORRECTNESS).
+ *
+ *   ?reps=R              timed repetitions, 0 < R <= 50 (default 5)
+ *   ?total=N             total pair count, 0 < N <= 2^20
+ *   ?sizes=A,B,...       batch sizes; each in (0, 4096], must divide the
+ *                        total and map to a power-of-two TPB
+ *   ?skip_correctness=1  disables the correctness check
+ *
+ * @returns the effective settings after applying the overrides.
+ * @throws Error when any supplied value is out of range or inconsistent.
+ */
+function parseParams() {
+  const qp = new URLSearchParams(window.location.search);
+  const reps = parseInt(qp.get('reps') ?? '5', 10);
+  if (!Number.isFinite(reps) || reps <= 0 || reps > 50) {
+    throw new Error(`?reps must be in (0, 50], got ${qp.get('reps')}`);
+  }
+  const totalStr = qp.get('total');
+  if (totalStr !== null) {
+    const total = parseInt(totalStr, 10);
+    if (!Number.isFinite(total) || total <= 0 || total > (1 << 20)) {
+      throw new Error(`?total must be in (0, 2^20], got ${totalStr}`);
+    }
+    TOTAL_PAIRS = total;
+  }
+  const sizesStr = qp.get('sizes');
+  if (sizesStr !== null) {
+    const sizes = sizesStr.split(',').map(s => parseInt(s, 10));
+    for (const s of sizes) {
+      if (!Number.isFinite(s) || s <= 0 || s > 4096) {
+        throw new Error(`?sizes entries must be in (0, 4096], got ${s}`);
+      }
+      if (TOTAL_PAIRS % s !== 0) {
+        throw new Error(`?sizes entry ${s} does not divide TOTAL_PAIRS=${TOTAL_PAIRS}`);
+      }
+      const { tpb } = tpbBsFor(s);
+      if ((tpb & (tpb - 1)) !== 0) {
+        throw new Error(`?sizes entry ${s} yields TPB=${tpb} which is not a power of two`);
+      }
+    }
+    BATCH_SIZES = sizes;
+  } else if (totalStr !== null) {
+    // The total was overridden but the batch sizes were not: the built-in
+    // sizes must still divide the new total, otherwise the sweep would run
+    // with a combination that is rejected when spelled out via ?sizes.
+    for (const s of BATCH_SIZES) {
+      if (TOTAL_PAIRS % s !== 0) {
+        throw new Error(`default batch size ${s} does not divide ?total=${TOTAL_PAIRS}; pass ?sizes explicitly`);
+      }
+    }
+  }
+  if (qp.get('skip_correctness') === '1') {
+    SKIP_CORRECTNESS = true;
+  }
+  return { reps, total: TOTAL_PAIRS, sizes: BATCH_SIZES, skip_correctness: SKIP_CORRECTNESS };
+}
+
+/**
+ * Bench driver. Parses the query parameters, acquires the WebGPU device,
+ * generates the reference pairs once and then runs `runOne` for every entry
+ * of BATCH_SIZES, recording each result in `benchState` and posting a
+ * progress event. The sweep stops at the first failing batch size. Every
+ * failure is reflected in `benchState.state` / `benchState.error` rather
+ * than rethrown.
+ */
+async function main() {
+  try {
+    const hasWebGpu = 'gpu' in navigator;
+    if (!hasWebGpu) {
+      throw new Error('navigator.gpu missing — WebGPU not available');
+    }
+
+    const params = parseParams();
+    benchState.params = params;
+    log('info', `params: reps=${params.reps} total=${params.total} sizes=[${params.sizes.join(',')}] skip_correctness=${params.skip_correctness}`);
+
+    benchState.state = 'running';
+    const device = await get_device();
+    log('info', 'WebGPU device acquired');
+
+    const fieldModulus = BN254_BASE_FIELD;
+    const montgomeryR = compute_misc_params(fieldModulus, 13).r;
+
+    log('info', `generating ${TOTAL_PAIRS} on-curve pairs via noble (this can take a few seconds)…`);
+    const genStart = performance.now();
+    const pairs = buildPairs(TOTAL_PAIRS, 0xc0ffee);
+    log('info', `pair generation done in ${(performance.now() - genStart).toFixed(0)} ms`);
+
+    const shaderManager = new ShaderManager(4, TOTAL_PAIRS, BN254_CURVE_CONFIG, false);
+
+    for (const batchSize of BATCH_SIZES) {
+      try {
+        const result = await runOne(device, shaderManager, batchSize, params.reps, montgomeryR, fieldModulus, pairs);
+        benchState.results.push(result);
+        const { median_ms, ns_per_pair, correctness } = result;
+        resultsClient.postProgress({ kind: 'batch_done', batch_size: batchSize, median_ms, ns_per_pair, correctness });
+      } catch (e) {
+        // One failing batch size aborts the remaining sweep.
+        const reason = e instanceof Error ? e.message : String(e);
+        log('err', `B=${batchSize} failed: ${reason} — STOPPING sweep at first failure`);
+        benchState.state = 'error';
+        benchState.error = reason;
+        return;
+      }
+    }
+
+    benchState.state = 'done';
+    log('ok', 'all batches done');
+  } catch (e) {
+    // Setup failures (bad params, no device, ...) also record the stack.
+    const reason = e instanceof Error ? `${e.message}\n${e.stack}` : String(e);
+    log('err', `FATAL: ${reason}`);
+    benchState.state = 'error';
+    benchState.error = reason;
+  }
+}
+
+// Kick off the bench. main() handles its own errors, so the .catch below is
+// only a backstop for anything that escapes it (for example an exception
+// thrown from inside one of main's own catch handlers).
+main()
+ .catch(e => {
+ const msg = e instanceof Error ? e.message : String(e);
+ log('err', `unhandled: ${msg}`);
+ benchState.state = 'error';
+ benchState.error = msg;
+ })
+ .finally(() => {
+ // Always post the final state, success or failure. Posting is best-effort:
+ // a rejection from postFinal is deliberately ignored.
+ postFinal().catch(() => {});
+ });
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-chain.html b/barretenberg/ts/dev/msm-webgpu/bench-msm-chain.html
new file mode 100644
index 000000000000..423d124585c3
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-chain.html
@@ -0,0 +1,37 @@
+
+
+
+
+ MSM bucket-accumulate pair-tree level-0 bench (marshal+chain)
+
+
+
+ MSM bucket-accumulate pair-tree level-0 bench (marshal + chain)
+ Query params: ?reps=R&pairs=N&buckets=B&s=S1,S2,...&wgi=W&disp=D
+
+
+
+
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-chain.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-chain.ts
new file mode 100644
index 000000000000..b644b8ab1598
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-chain.ts
@@ -0,0 +1,625 @@
+///
+// bench-msm-chain — Standalone WebGPU bench for the marshal + chain
+// pair-tree level-0 pipeline that integrates the ba_rev_packed_carry
+// chain kernel into MSM bucket accumulate.
+//
+// The harness:
+// 1. Generates a point pool of N+1 random Montgomery points in the
+// SoA-packed layout (2 planes, PG=2 vec4/elem). Index 0 is a
+// universal decoy seed; indices 1..N are real points.
+// 2. Generates a synthetic CSR: B buckets, each point assigned to a
+// uniformly random bucket. csr_indices = points sorted by bucket;
+// offset[] = CSR row pointers; count[b] = bucket b's point count.
+// 3. Builds a chunk plan from the dense slices: for each bucket b
+// with count[b] >= S, splits it into floor(count[b]/S) chunks of
+// exactly S points. Total dense chunks T_dense.
+// 4. Runs the marshal kernel: T_dense threads, each gathers S point
+// coords from the pool into the strided chain layout.
+// 5. Runs the recovered ba_rev_packed_carry chain kernel on the
+// marshaled layout.
+// 6. Times marshal and chain separately (DISP back-to-back dispatches
+// per timed sample, amortising submit + drain).
+//
+// Reported metrics, sweeping S in {16, 32, 64}:
+// marshal_ns_per_pt — ns per input point processed by marshal
+// chain_ns_per_pt — ns per input point processed by chain
+// combined_ns_per_pt — marshal + chain
+// density — fraction of N points covered by dense chunks
+//
+// Out of scope here (covered by follow-on passes):
+// - Reduction passes that fold pair-tree levels 1..log2(S)/2 into
+// per-bucket totals. These reuse the same chain kernel with
+// decreasing T; cost is ~25 ns/pt per level, ~3 levels for S=16.
+// - Tail handling for buckets with count < S. Will reuse the
+// workgroup-scan batch-affine path or a variable-length variant.
+
+import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
+import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
+import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
+import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js';
+import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js';
+import { makeResultsClient } from './results_post.js';
+
+// vec4 groups per field element: 2 x 4 u32 words = one 8-word element.
+const PG = 2;
+// Defaults for the tunables below; each can be overridden via the query string.
+const DEFAULT_PAIRS = 1 << 17; // 131072
+const DEFAULT_BUCKETS = 1 << 13; // 8192 -> avg bucket size = 16
+const DEFAULT_WGI = 64;
+const DEFAULT_DISP = 8;
+const DEFAULT_S_SWEEP: readonly number[] = [16, 32, 64];
+
+// Effective settings. Declared with `let` because parseParams() overwrites
+// them from ?pairs, ?buckets, ?wgi, ?disp and ?s respectively.
+let PAIRS = DEFAULT_PAIRS; // number of real points (the pool holds PAIRS + 1)
+let BUCKETS = DEFAULT_BUCKETS; // number of synthetic CSR buckets
+let WGI = DEFAULT_WGI; // divisor used to turn T chunks into a workgroup count
+let DISP = DEFAULT_DISP; // compute passes encoded per timed sample
+let S_SWEEP: readonly number[] = DEFAULT_S_SWEEP; // chunk lengths S to benchmark
+
+/**
+ * Builds a deterministic 32-bit linear congruential generator
+ * (multiplier 1664525, increment 1013904223, modulus 2^32).
+ *
+ * A seed of 0 (after coercion to uint32) is replaced by 1. Each call of the
+ * returned function advances the state and returns it as an unsigned 32-bit
+ * integer. Because the modulus is a power of two, the low k bits of the
+ * output repeat with period 2^k; the high bits are the better-mixed ones.
+ */
+function makeRng(seed: number): () => number {
+  let current = seed >>> 0;
+  if (current === 0) current = 1;
+  const next = (): number => {
+    current = (Math.imul(current, 1664525) + 1013904223) >>> 0;
+    return current;
+  };
+  return next;
+}
+
+/**
+ * Draws a value v with 0 < v < p by rejection sampling.
+ *
+ * Each candidate is assembled from ceil(bitlen(p) / 8) bytes, masked to
+ * bitlen(p) bits, and rejected unless it lies strictly between 0 and p.
+ *
+ * Every byte is taken from the TOP 8 bits of an `rng()` output. `rng` is
+ * expected to return unsigned 32-bit integers; with a power-of-two-modulus
+ * LCG the low 8 bits repeat every 256 calls, which would make a 32-byte
+ * candidate repeat after only 8 draws.
+ *
+ * @param p   exclusive upper bound; must be greater than 1.
+ * @param rng source of 32-bit random integers.
+ * @throws RangeError when p <= 1 (no value satisfies 0 < v < p, so the
+ *         rejection loop could never terminate).
+ */
+function randomBelow(p: bigint, rng: () => number): bigint {
+  if (p <= 1n) {
+    throw new RangeError(`randomBelow: p must be > 1, got ${p}`);
+  }
+  const bitlen = p.toString(2).length;
+  const byteLen = Math.ceil(bitlen / 8);
+  // Loop-invariant, so computed once instead of per candidate.
+  const mask = (1n << BigInt(bitlen)) - 1n;
+  for (;;) {
+    let v = 0n;
+    for (let i = 0; i < byteLen; i++) v = (v << 8n) | BigInt(rng() >>> 24);
+    v &= mask;
+    if (v > 0n && v < p) return v;
+  }
+}
+
+/**
+ * Splits `v` into eight 32-bit limbs, least-significant limb first.
+ * Bits at position 256 and above are discarded.
+ */
+function bigintToPackedU32x8(v: bigint): Uint32Array {
+  const limbs = new Uint32Array(8);
+  for (let limb = 0; limb < limbs.length; limb++) {
+    const shifted = v >> BigInt(32 * limb);
+    limbs[limb] = Number(BigInt.asUintN(32, shifted));
+  }
+  return limbs;
+}
+
+/**
+ * Median of `xs`, computed on a sorted copy so the input is not mutated.
+ *
+ * For an odd count this is the middle element; for an even count it is the
+ * mean of the two middle elements (taking only the upper one would bias the
+ * reported time upwards whenever an even ?reps is used).
+ *
+ * @returns NaN for an empty array.
+ */
+function median(xs: number[]): number {
+  if (xs.length === 0) return NaN;
+  const sorted = xs.slice().sort((a, b) => a - b);
+  const mid = sorted.length >> 1;
+  if (sorted.length % 2 === 1) return sorted[mid];
+  return (sorted[mid - 1] + sorted[mid]) / 2;
+}
+
+/**
+ * Builds a pool of `poolSize` random coordinate pairs in the plane-major
+ * layout: plane 0 holds x, plane 1 holds y, and every element is split into
+ * PG groups of 4 u32 words. The word offset of group g of element e in
+ * plane q is ((q * PG + g) * poolSize + e) * 4.
+ *
+ * Each coordinate is randomBelow(p) * R mod p. The caller reserves pool
+ * index 0 as the decoy seed.
+ */
+function buildPointPool(poolSize: number, R: bigint, p: bigint, rng: () => number): Uint32Array {
+  const packed = new Uint32Array(2 * PG * poolSize * 4);
+  for (let idx = 0; idx < poolSize; idx++) {
+    // x is drawn before y so the RNG stream is consumed in a fixed order.
+    const limbsX = bigintToPackedU32x8((randomBelow(p, rng) * R) % p);
+    const limbsY = bigintToPackedU32x8((randomBelow(p, rng) * R) % p);
+    for (let group = 0; group < PG; group++) {
+      const dstX = (group * poolSize + idx) * 4;
+      const dstY = ((PG + group) * poolSize + idx) * 4;
+      packed.set(limbsX.subarray(4 * group, 4 * group + 4), dstX);
+      packed.set(limbsY.subarray(4 * group, 4 * group + 4), dstY);
+    }
+  }
+  return packed;
+}
+
+/**
+ * Builds a synthetic CSR by assigning each of the N points to a
+ * pseudo-random bucket in [0, B).
+ *
+ * The bucket is chosen by scaling the 32-bit `rng()` output into [0, B),
+ * which keys on the HIGH bits. A plain `rng() % B` with power-of-two B would
+ * use only the low log2(B) bits of a mod-2^32 LCG; those cycle with period
+ * B, so points would be dealt round-robin and every bucket would end up with
+ * exactly N / B points.
+ *
+ * @param N   number of points; point indices are 1-based (1..N) because
+ *            pool index 0 is reserved for the decoy.
+ * @param B   number of buckets.
+ * @param rng source of unsigned 32-bit random integers.
+ * @returns csrIndices: point indices grouped by bucket, in assignment order
+ *          within a bucket; offsets: B + 1 row pointers into csrIndices;
+ *          counts: points per bucket.
+ */
+function buildSyntheticCSR(
+  N: number,
+  B: number,
+  rng: () => number,
+): { csrIndices: Uint32Array; offsets: Uint32Array; counts: Uint32Array } {
+  const bucket = new Uint32Array(N);
+  const counts = new Uint32Array(B);
+  for (let i = 0; i < N; i++) {
+    // rng() / 2^32 lies in [0, 1), so the scaled value lies in [0, B).
+    const b = Math.floor(((rng() >>> 0) / 0x100000000) * B);
+    bucket[i] = b;
+    counts[b]++;
+  }
+  const offsets = new Uint32Array(B + 1);
+  for (let b = 0; b < B; b++) offsets[b + 1] = offsets[b] + counts[b];
+  const cursor = new Uint32Array(B);
+  const csrIndices = new Uint32Array(N);
+  for (let i = 0; i < N; i++) {
+    const b = bucket[i];
+    csrIndices[offsets[b] + cursor[b]++] = i + 1; // 1-based: 0 is decoy
+  }
+  return { csrIndices, offsets, counts };
+}
+
+/**
+ * Derives the dense chunk plan from the CSR row pointers.
+ *
+ * Every bucket b contributes floor(counts[b] / S) chunks of S consecutive
+ * csr_indices entries. Chunk t is stored as two u32 words:
+ * chunkPlan[2t] = bucket id, chunkPlan[2t + 1] = first csr_indices position
+ * of the chunk. `tailPoints` is the sum of the leftovers (counts[b] mod S)
+ * that do not fill a whole chunk.
+ */
+function buildChunkPlan(
+  offsets: Uint32Array,
+  counts: Uint32Array,
+  S: number,
+): { chunkPlan: Uint32Array; T: number; tailPoints: number } {
+  const chunksPerBucket = Array.from(counts, count => Math.floor(count / S));
+  const T = chunksPerBucket.reduce((total, n) => total + n, 0);
+  const tailPoints = counts.reduce((total, count) => total + (count % S), 0);
+
+  const chunkPlan = new Uint32Array(2 * T);
+  let slot = 0;
+  chunksPerBucket.forEach((nChunks, b) => {
+    for (let k = 0; k < nChunks; k++, slot += 2) {
+      chunkPlan[slot] = b;
+      chunkPlan[slot + 1] = offsets[b] + k * S;
+    }
+  });
+  return { chunkPlan, T, tailPoints };
+}
+
+// Result record produced by runOne() for one chunk length S.
+interface PerSizeResult {
+ // Chunk length S used for this run.
+ s: number;
+ // Run configuration (WGI, PAIRS, BUCKETS) in effect for this run.
+ wgi: number;
+ pairs: number;
+ buckets: number;
+ // Number of dense chunks in the chunk plan.
+ T: number;
+ // Points left over in partially filled chunks (not processed).
+ tail_points: number;
+ // (T * s) / pairs: fraction of the points covered by dense chunks.
+ density: number;
+ // Compute passes encoded per timed sample (DISP).
+ disp: number;
+ // Marshal kernel: wall-clock statistics over the timed samples, in ms.
+ marshal_median_ms: number;
+ marshal_min_ms: number;
+ marshal_max_ms: number;
+ // Median sample time in ns divided by T * s * disp.
+ marshal_ns_per_pt: number;
+ // Chain kernel: same statistics as for marshal.
+ chain_median_ms: number;
+ chain_min_ms: number;
+ chain_max_ms: number;
+ chain_ns_per_pt: number;
+ // marshal_ns_per_pt + chain_ns_per_pt.
+ combined_ns_per_pt: number;
+ // Raw per-sample timings in ms.
+ marshal_samples_ms: number[];
+ chain_samples_ms: number[];
+ // True when the first 8 u32 words of the chain output are not all zero.
+ sanity_ok: boolean;
+}
+
+// Mutable state of the whole bench run; published on window.__bench and
+// sent by postFinal().
+interface BenchState {
+ // 'boot' initially; main() moves it to 'running', then 'done' or 'error'.
+ state: 'boot' | 'running' | 'done' | 'error';
+ // Effective parameters returned by parseParams(); null until parsed.
+ params: { reps: number; pairs: number; buckets: number; wgi: number; disp: number; s_sweep: readonly number[] } | null;
+ // One entry per completed chunk length, in sweep order.
+ results: PerSizeResult[];
+ // Message of the failure that ended the run, if any.
+ error: string | null;
+ // Every log() line, formatted as "[level] message".
+ log: string[];
+}
+
+// Single shared state object, mutated in place by main() and log().
+const benchState: BenchState = {
+ state: 'boot',
+ params: null,
+ results: [],
+ error: null,
+ log: [],
+};
+// Publish the state on the page's window object under __bench.
+(window as unknown as { __bench: BenchState }).__bench = benchState;
+
+// Client used to post progress events and the final results for this page;
+// its run id is published on window.__runId.
+const resultsClient = makeResultsClient({ page: 'bench-msm-chain' });
+(window as unknown as { __runId: string }).__runId = resultsClient.runId;
+
+/**
+ * Posts the final bench state (status, parameters, per-size results, error,
+ * log lines) together with the browser's user agent and hardware
+ * concurrency through the results client.
+ */
+async function postFinal(): Promise<void> {
+  await resultsClient.postResults({
+    state: benchState.state,
+    params: benchState.params,
+    results: benchState.results,
+    error: benchState.error,
+    log: benchState.log,
+    userAgent: navigator.userAgent,
+    hardwareConcurrency: navigator.hardwareConcurrency,
+  });
+}
+
+// Container element (id "log") that receives one <div> per log line.
+const $log = document.getElementById('log') as HTMLDivElement;
+
+/**
+ * Records one log line in three places: a <div> appended to the page
+ * (CSS class 'ok' / 'err' / 'warn', or no class for 'info'), the
+ * `benchState.log` array (prefixed with the level) and the browser console.
+ */
+function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) {
+  let cssClass = '';
+  switch (level) {
+    case 'ok':
+    case 'err':
+    case 'warn':
+      cssClass = level;
+      break;
+  }
+  const row = document.createElement('div');
+  row.className = cssClass;
+  row.textContent = msg;
+  $log.appendChild(row);
+  benchState.log.push(`[${level}] ${msg}`);
+  console.log(`[bench-msm-chain] ${msg}`);
+}
+
+// Pairs a compute pipeline with the bind group layout it was created for.
+// NOTE(review): nothing else in this file refers to this interface —
+// confirm whether it is still needed.
+interface PipelineInfo {
+ pipeline: GPUComputePipeline;
+ layout: GPUBindGroupLayout;
+}
+
+/**
+ * Compiles `code` into a compute pipeline with entry point `main`.
+ *
+ * All compilation messages are reported through the console and log():
+ * errors at 'err' level, everything else at 'warn' level.
+ *
+ * @param device     WebGPU device used for compilation.
+ * @param code       WGSL source.
+ * @param cacheKey   label used only to tag log and error messages.
+ * @param bindLayout layout of bind group 0.
+ * @returns the created compute pipeline.
+ * @throws Error listing every error message when compilation reports at
+ *         least one error.
+ */
+async function compile(
+  device: GPUDevice,
+  code: string,
+  cacheKey: string,
+  bindLayout: GPUBindGroupLayout,
+): Promise<GPUComputePipeline> {
+  const module = device.createShaderModule({ code });
+  const info = await module.getCompilationInfo();
+  const errLines: string[] = [];
+  for (const msg of info.messages) {
+    const line = `[shader ${cacheKey}] ${msg.type}: ${msg.message} (line ${msg.lineNum}, col ${msg.linePos})`;
+    if (msg.type === 'error') {
+      console.error(line);
+      log('err', line);
+      errLines.push(line);
+    } else {
+      console.warn(line);
+      log('warn', line);
+    }
+  }
+  if (errLines.length > 0) {
+    throw new Error(`WGSL compile failed for ${cacheKey}: ${errLines.join(' | ')}`);
+  }
+  return device.createComputePipelineAsync({
+    layout: device.createPipelineLayout({ bindGroupLayouts: [bindLayout] }),
+    compute: { module, entryPoint: 'main' },
+  });
+}
+
+/**
+ * Bind group layout for the chain kernel, compute-stage only:
+ * bindings 0 and 1 are read-only storage buffers, binding 2 is a writable
+ * storage buffer and binding 3 is a uniform buffer.
+ */
+function chainLayout(device: GPUDevice): GPUBindGroupLayout {
+  const bufferTypes: GPUBufferBindingType[] = ['read-only-storage', 'read-only-storage', 'storage', 'uniform'];
+  return device.createBindGroupLayout({
+    entries: bufferTypes.map((type, binding) => ({
+      binding,
+      visibility: GPUShaderStage.COMPUTE,
+      buffer: { type },
+    })),
+  });
+}
+
+/**
+ * Bind group layout for the marshal kernel, compute-stage only:
+ * bindings 0-2 are read-only storage buffers, binding 3 is a writable
+ * storage buffer and binding 4 is a uniform buffer.
+ */
+function marshalLayout(device: GPUDevice): GPUBindGroupLayout {
+  const bufferTypes: GPUBufferBindingType[] = [
+    'read-only-storage',
+    'read-only-storage',
+    'read-only-storage',
+    'storage',
+    'uniform',
+  ];
+  return device.createBindGroupLayout({
+    entries: bufferTypes.map((type, binding) => ({
+      binding,
+      visibility: GPUShaderStage.COMPUTE,
+      buffer: { type },
+    })),
+  });
+}
+
+/**
+ * Times `reps` submissions of a command buffer that contains `passes`
+ * back-to-back compute passes, each dispatching `numWgs` workgroups of
+ * `pipeline` with `bind` as bind group 0.
+ *
+ * One untimed warmup submission of the same shape runs first. Each sample
+ * is the wall-clock time in milliseconds from just before the command
+ * buffer is finished and submitted until the queue reports the submitted
+ * work as done.
+ *
+ * @returns one sample per repetition, in ms.
+ */
+async function timeDispatchPasses(
+  device: GPUDevice,
+  pipeline: GPUComputePipeline,
+  bind: GPUBindGroup,
+  numWgs: number,
+  reps: number,
+  passes: number,
+): Promise<number[]> {
+  // Records `passes` identical compute passes into a fresh encoder. Shared
+  // by the warmup and the timed loop so both submit the same workload.
+  const encodePasses = (): GPUCommandEncoder => {
+    const enc = device.createCommandEncoder();
+    for (let pIdx = 0; pIdx < passes; pIdx++) {
+      const pass = enc.beginComputePass();
+      pass.setPipeline(pipeline);
+      pass.setBindGroup(0, bind);
+      pass.dispatchWorkgroups(numWgs, 1, 1);
+      pass.end();
+    }
+    return enc;
+  };
+
+  // warmup
+  device.queue.submit([encodePasses().finish()]);
+  await device.queue.onSubmittedWorkDone();
+
+  const samples: number[] = [];
+  for (let r = 0; r < reps; r++) {
+    const enc = encodePasses();
+    const t0 = performance.now();
+    device.queue.submit([enc.finish()]);
+    await device.queue.onSubmittedWorkDone();
+    samples.push(performance.now() - t0);
+  }
+  return samples;
+}
+
+/**
+ * Copies the first `u32Count` 32-bit words of `buf` into a temporary
+ * staging buffer, reads them back and reports whether any word is non-zero.
+ *
+ * `buf` is used as the source of a buffer-to-buffer copy. The staging
+ * buffer is destroyed even when the copy or the mapping fails.
+ */
+async function readNonZero(device: GPUDevice, buf: GPUBuffer, u32Count: number): Promise<boolean> {
+  const bytes = u32Count * 4;
+  const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST });
+  try {
+    const enc = device.createCommandEncoder();
+    enc.copyBufferToBuffer(buf, 0, staging, 0, bytes);
+    device.queue.submit([enc.finish()]);
+    await device.queue.onSubmittedWorkDone();
+    await staging.mapAsync(GPUMapMode.READ);
+    // slice(0) copies the data out so it stays valid after unmap().
+    const words = new Uint32Array(staging.getMappedRange().slice(0));
+    staging.unmap();
+    return words.some(word => word !== 0);
+  } finally {
+    staging.destroy();
+  }
+}
+
+/**
+ * Runs the marshal + chain bench for one chunk length `s`.
+ *
+ * Steps: build the random point pool and synthetic CSR from `seed`, derive
+ * the dense chunk plan, upload everything, compile both kernels, time the
+ * marshal kernel and then the chain kernel (DISP passes per sample, `reps`
+ * samples each), and finally check that the chain output is not all zero.
+ *
+ * @param device WebGPU device.
+ * @param sm     shader manager that generates the two WGSL sources.
+ * @param s      chunk length S.
+ * @param reps   number of timed samples per kernel.
+ * @param R      multiplier applied to every random coordinate (mod p).
+ * @param p      field modulus.
+ * @param seed   RNG seed for the pool and the bucket assignment.
+ * @returns timing statistics and plan metadata for this S.
+ * @throws Error when no bucket holds at least `s` points (T = 0), or when
+ *         shader compilation fails.
+ */
+async function runOne(
+  device: GPUDevice,
+  sm: ShaderManager,
+  s: number,
+  reps: number,
+  R: bigint,
+  p: bigint,
+  seed: number,
+): Promise<PerSizeResult> {
+  log('info', `=== S=${s}: PAIRS=${PAIRS} BUCKETS=${BUCKETS} WGI=${WGI} DISP=${DISP}`);
+
+  const rng = makeRng(seed);
+  const poolSize = PAIRS + 1; // index 0 reserved as decoy
+  const poolU32 = buildPointPool(poolSize, R, p, rng);
+  // The CSR builder already returns the row pointers the chunk plan needs.
+  const { csrIndices, offsets, counts } = buildSyntheticCSR(PAIRS, BUCKETS, rng);
+
+  const { chunkPlan, T, tailPoints } = buildChunkPlan(offsets, counts, s);
+  const density = (T * s) / PAIRS;
+  const numWgs = Math.ceil(T / WGI);
+  log(
+    'info',
+    `chunk plan: T=${T} chunks (S=${s} each) -> ${T * s} pts (${(density * 100).toFixed(1)}% of ${PAIRS}); tail=${tailPoints} pts skipped`,
+  );
+
+  if (T === 0) {
+    throw new Error(`S=${s}: no dense chunks (every bucket has count<${s}). Try smaller S or fewer buckets.`);
+  }
+
+  // Every GPU buffer created below is tracked here so that it is destroyed
+  // on all exit paths, including a failed shader compile.
+  const owned: GPUBuffer[] = [];
+  const own = (buffer: GPUBuffer): GPUBuffer => {
+    owned.push(buffer);
+    return buffer;
+  };
+  const mkSb = (size: number, copyDst: boolean, copySrc: boolean): GPUBuffer => {
+    let usage = GPUBufferUsage.STORAGE;
+    if (copyDst) usage |= GPUBufferUsage.COPY_DST;
+    if (copySrc) usage |= GPUBufferUsage.COPY_SRC;
+    return own(device.createBuffer({ size, usage }));
+  };
+  const mkUniform = (): GPUBuffer =>
+    own(device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }));
+
+  try {
+    // GPU buffers.
+    const poolBuf = mkSb(poolU32.byteLength, true, false);
+    const csrBuf = mkSb(csrIndices.byteLength, true, false);
+    const chunkBuf = mkSb(chunkPlan.byteLength, true, false);
+    const chainBytes = 4 * PG * (T * s) * 4 * 4; // 4 planes * PG vec4/elem * (T*S) elems * 4 u32/vec4 * 4 B/u32
+    const chainBuf = mkSb(chainBytes, false, true);
+    const marshalParams = mkUniform();
+    const chainParams = mkUniform();
+
+    device.queue.writeBuffer(poolBuf, 0, poolU32);
+    device.queue.writeBuffer(csrBuf, 0, csrIndices);
+    device.queue.writeBuffer(chunkBuf, 0, chunkPlan);
+    // marshal: params = [T, poolSize, _, _]
+    device.queue.writeBuffer(marshalParams, 0, new Uint32Array([T, poolSize, 0, 0]));
+    // chain: params = [N_chain, T, _, _] with N_chain = T*S
+    device.queue.writeBuffer(chainParams, 0, new Uint32Array([T * s, T, 0, 0]));
+
+    // Compile marshal pipeline.
+    const marshalCode = sm.gen_ba_marshal_chain_shader(WGI, s);
+    log('info', `marshal shader ${marshalCode.length} chars`);
+    const mLayout = marshalLayout(device);
+    const marshalPipeline = await compile(device, marshalCode, `marshal-W${WGI}-S${s}`, mLayout);
+    const marshalBind = device.createBindGroup({
+      layout: mLayout,
+      entries: [
+        { binding: 0, resource: { buffer: csrBuf } },
+        { binding: 1, resource: { buffer: chunkBuf } },
+        { binding: 2, resource: { buffer: poolBuf } },
+        { binding: 3, resource: { buffer: chainBuf } },
+        { binding: 4, resource: { buffer: marshalParams } },
+      ],
+    });
+
+    // Compile chain pipeline.
+    const chainCode = sm.gen_ba_rev_packed_carry_bench_shader(WGI, s);
+    log('info', `chain shader ${chainCode.length} chars`);
+    const cLayout = chainLayout(device);
+    const chainPipeline = await compile(device, chainCode, `chain-W${WGI}-S${s}`, cLayout);
+    const dummy = own(device.createBuffer({ size: 16, usage: GPUBufferUsage.STORAGE }));
+    // Chain output buffer (separate from input chain_buf since the kernel
+    // writes its R outputs to a 2-plane output buffer).
+    const chainOutBytes = 2 * PG * (T * s) * 4 * 4;
+    const chainOutBuf = mkSb(chainOutBytes, false, true);
+    const chainBind = device.createBindGroup({
+      layout: cLayout,
+      entries: [
+        { binding: 0, resource: { buffer: chainBuf } },
+        { binding: 1, resource: { buffer: dummy } },
+        { binding: 2, resource: { buffer: chainOutBuf } },
+        { binding: 3, resource: { buffer: chainParams } },
+      ],
+    });
+
+    log('info', `marshal: numWgs=${numWgs}, ${T} threads, ${T * s} pts gathered/dispatch`);
+    log('info', `chain : numWgs=${numWgs}, ${T} threads, S=${s} adds/thread = ${T * s} pts processed/dispatch`);
+
+    // Marshal must run at least once before chain (chain reads chain_buf).
+    // Warmup is built into timeDispatchPasses.
+    const marshalSamples = await timeDispatchPasses(device, marshalPipeline, marshalBind, numWgs, reps, DISP);
+    const chainSamples = await timeDispatchPasses(device, chainPipeline, chainBind, numWgs, reps, DISP);
+
+    // Cheap sanity check: only the first 8 u32 words of the output are read.
+    const sanityOk = await readNonZero(device, chainOutBuf, 8);
+
+    const marshalMed = median(marshalSamples);
+    const chainMed = median(chainSamples);
+    const ptsPerSample = T * s * DISP;
+    const marshalNsPerPt = (marshalMed * 1e6) / ptsPerSample;
+    const chainNsPerPt = (chainMed * 1e6) / ptsPerSample;
+    const combinedNsPerPt = marshalNsPerPt + chainNsPerPt;
+
+    log(
+      sanityOk ? 'ok' : 'err',
+      `S=${s}: marshal=${marshalNsPerPt.toFixed(2)}ns/pt chain=${chainNsPerPt.toFixed(2)}ns/pt combined=${combinedNsPerPt.toFixed(2)}ns/pt density=${(density * 100).toFixed(1)}% sanity=${sanityOk ? 'OK' : 'FAIL'}`,
+    );
+
+    return {
+      s,
+      wgi: WGI,
+      pairs: PAIRS,
+      buckets: BUCKETS,
+      T,
+      tail_points: tailPoints,
+      density,
+      disp: DISP,
+      marshal_median_ms: marshalMed,
+      marshal_min_ms: Math.min(...marshalSamples),
+      marshal_max_ms: Math.max(...marshalSamples),
+      marshal_ns_per_pt: marshalNsPerPt,
+      chain_median_ms: chainMed,
+      chain_min_ms: Math.min(...chainSamples),
+      chain_max_ms: Math.max(...chainSamples),
+      chain_ns_per_pt: chainNsPerPt,
+      combined_ns_per_pt: combinedNsPerPt,
+      marshal_samples_ms: marshalSamples,
+      chain_samples_ms: chainSamples,
+      sanity_ok: sanityOk,
+    };
+  } finally {
+    for (const buffer of owned) buffer.destroy();
+  }
+}
+
+/**
+ * Reads the bench options from the page's query string, validates them and
+ * stores the overrides in the module-level settings.
+ *
+ *   ?reps=R       timed samples per kernel, 0 < R <= 50 (default 5)
+ *   ?pairs=N      point count, 0 < N <= 2^20          -> PAIRS
+ *   ?buckets=B    bucket count, 0 < B <= 2^18         -> BUCKETS
+ *   ?wgi=W        0 < W <= 1024                       -> WGI
+ *   ?disp=D       passes per sample, 0 < D <= 64      -> DISP
+ *   ?s=A,B,...    chunk lengths, each in (0, 256]     -> S_SWEEP
+ *
+ * @returns the effective settings after applying the overrides.
+ * @throws Error for the first value that is not an integer in range.
+ */
+function parseParams() {
+  const qp = new URLSearchParams(window.location.search);
+
+  // Parses `text` as a base-10 integer and enforces 0 < value <= max.
+  // `shown` is the raw text echoed in the error message.
+  const inRange = (
+    name: string,
+    text: string,
+    max: number,
+    maxLabel: string,
+    shown: string | null = text,
+  ): number => {
+    const value = parseInt(text, 10);
+    if (!Number.isFinite(value) || value <= 0 || value > max) {
+      throw new Error(`?${name} must be in (0, ${maxLabel}], got ${shown}`);
+    }
+    return value;
+  };
+
+  const repsRaw = qp.get('reps');
+  const reps = inRange('reps', repsRaw ?? '5', 50, '50', repsRaw);
+
+  const pairsStr = qp.get('pairs');
+  if (pairsStr !== null) PAIRS = inRange('pairs', pairsStr, 1 << 20, '2^20');
+
+  const bucketsStr = qp.get('buckets');
+  if (bucketsStr !== null) BUCKETS = inRange('buckets', bucketsStr, 1 << 18, '2^18');
+
+  const wgiStr = qp.get('wgi');
+  if (wgiStr !== null) WGI = inRange('wgi', wgiStr, 1024, '1024');
+
+  const dispStr = qp.get('disp');
+  if (dispStr !== null) DISP = inRange('disp', dispStr, 64, '64');
+
+  const sStr = qp.get('s');
+  if (sStr !== null) {
+    const list = sStr.split(',').map(entry => parseInt(entry, 10));
+    const bad = list.find(v => !Number.isFinite(v) || v <= 0 || v > 256);
+    if (bad !== undefined) {
+      throw new Error(`?s entries must be in (0, 256], got ${bad}`);
+    }
+    S_SWEEP = list;
+  }
+
+  return { reps, pairs: PAIRS, buckets: BUCKETS, wgi: WGI, disp: DISP, s_sweep: S_SWEEP };
+}
+
+/**
+ * Bench driver. Parses the query parameters, acquires the WebGPU device and
+ * runs `runOne` for every chunk length in S_SWEEP, recording each result in
+ * `benchState` and posting a progress event. The i-th chunk length uses
+ * seed 0xc511 + 0x10 * i. The sweep stops at the first failing chunk
+ * length. Every failure is reflected in `benchState.state` /
+ * `benchState.error` rather than rethrown.
+ */
+async function main() {
+  try {
+    const hasWebGpu = 'gpu' in navigator;
+    if (!hasWebGpu) {
+      throw new Error('navigator.gpu missing — WebGPU not available');
+    }
+
+    const params = parseParams();
+    benchState.params = params;
+    log(
+      'info',
+      `params: reps=${params.reps} pairs=${params.pairs} buckets=${params.buckets} wgi=${params.wgi} disp=${params.disp} s=[${params.s_sweep.join(',')}]`,
+    );
+
+    benchState.state = 'running';
+    const device = await get_device();
+    log('info', 'WebGPU device acquired');
+
+    const fieldModulus = BN254_BASE_FIELD;
+    const montgomeryR = compute_misc_params(fieldModulus, 13).r;
+
+    const shaderManager = new ShaderManager(4, PAIRS, BN254_CURVE_CONFIG, false);
+
+    for (const [index, s] of S_SWEEP.entries()) {
+      const seed = 0xc511 + 0x10 * index;
+      try {
+        const result = await runOne(device, shaderManager, s, params.reps, montgomeryR, fieldModulus, seed);
+        benchState.results.push(result);
+        const { marshal_ns_per_pt, chain_ns_per_pt, combined_ns_per_pt, density, sanity_ok } = result;
+        resultsClient.postProgress({
+          kind: 'batch_done',
+          s,
+          marshal_ns_per_pt,
+          chain_ns_per_pt,
+          combined_ns_per_pt,
+          density,
+          sanity_ok,
+        });
+      } catch (e) {
+        // One failing chunk length aborts the remaining sweep.
+        const reason = e instanceof Error ? e.message : String(e);
+        log('err', `S=${s} failed: ${reason} — STOPPING sweep at first failure`);
+        benchState.state = 'error';
+        benchState.error = reason;
+        return;
+      }
+    }
+
+    benchState.state = 'done';
+    log('ok', 'all sizes done');
+  } catch (e) {
+    // Setup failures (bad params, no device, ...) also record the stack.
+    const reason = e instanceof Error ? `${e.message}\n${e.stack}` : String(e);
+    log('err', `FATAL: ${reason}`);
+    benchState.state = 'error';
+    benchState.error = reason;
+  }
+}
+
+// Kick off the bench. main() handles its own errors, so the .catch below is
+// only a backstop for anything that escapes it (for example an exception
+// thrown from inside one of main's own catch handlers).
+main()
+ .catch(e => {
+ const msg = e instanceof Error ? e.message : String(e);
+ log('err', `unhandled: ${msg}`);
+ benchState.state = 'error';
+ benchState.error = msg;
+ })
+ .finally(() => {
+ // Always post the final state, success or failure. Posting is best-effort:
+ // a rejection from postFinal is deliberately ignored.
+ postFinal().catch(() => {});
+ });
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle-prod.html b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle-prod.html
new file mode 100644
index 000000000000..5126b5f73cc0
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle-prod.html
@@ -0,0 +1,22 @@
+
+
+
+
+ v2 prod orchestrator + indirect dispatch noble oracle (WebGPU)
+
+
+
+ v2 prod orchestrator + indirect dispatch noble oracle
+ Query params: ?subtasks=T&columns=B&input=N&s=S&wgi=W&tpb=TPB&per=PER&lvls=L&seed=K
+
+
+
+
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle-prod.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle-prod.ts
new file mode 100644
index 000000000000..10e2be473956
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle-prod.ts
@@ -0,0 +1,349 @@
+///
+// End-to-end correctness oracle for the prod v2 pair-tree orchestrator.
+//
+// Wraps runSmvpV2PairTree (cuzk/smvp_v2_pair_tree.ts) with a noble-CPU
+// cross-check on real BN254 affine points. Validates the full prod
+// path: csr_to_v2_meta + csr_to_v2_active_sums + planner_v2_prod +
+// marshal_prod + disjoint_prod + scatter_prod + carry_prod +
+// v2_to_running, with indirect dispatch driven by the planner's
+// per-level totals.
+//
+// Sizing: small by design (num_subtasks=1, num_columns=32, N=256) so
+// noble's projective bucket sum runs instantly in the browser. Each
+// bucket gets ~8 points; pair-tree needs ~4 levels.
+
+import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
+import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
+import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
+import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js';
+import { BN254_BASE_FIELD, modInverse } from '../../src/msm_webgpu/cuzk/bn254.js';
+import { runSmvpV2PairTree } from '../../src/msm_webgpu/cuzk/smvp_v2_pair_tree.js';
+import { makeResultsClient } from './results_post.js';
+import { bn254 } from '@noble/curves/bn254';
+
+// Run configuration defaults. parseParams() overwrites any of these whose URL
+// query parameter is present; the rest of this file reads them as globals.
+let NUM_SUBTASKS = 1; // ?subtasks: number of synthetic CSR windows; passed as num_subtasks
+let NUM_COLUMNS = 32; // ?columns: buckets per subtask; passed as num_columns
+let INPUT_SIZE = 256; // ?input: number of generated points; passed as input_size
+let S = 16; // ?s: forwarded unchanged as `s` to runSmvpV2PairTree
+let WGI = 64; // ?wgi: forwarded unchanged as `wgi`
+let TPB = 64; // ?tpb: forwarded unchanged as `tpb`
+let PER_THREAD = 1; // ?per: forwarded unchanged as `per_thread`
+let MAX_LEVELS = 8; // ?lvls: forwarded unchanged as `max_levels`
+let SEED = 0xc0de; // ?seed: seeds makeRng() for point scalars and bucket assignment
+
+/**
+ * Builds a deterministic 32-bit linear congruential generator.
+ *
+ * @param seed Initial state, truncated to an unsigned 32-bit integer; a seed
+ *   that truncates to 0 is replaced by 1.
+ * @returns A function that advances the state and returns it as a uint32.
+ */
+function makeRng(seed: number): () => number {
+  const MULTIPLIER = 1664525;
+  const INCREMENT = 1013904223;
+  const seed32 = seed >>> 0;
+  let current = seed32 === 0 ? 1 : seed32;
+  return function next(): number {
+    // Math.imul keeps the product in 32 bits; `>>> 0` makes it unsigned again.
+    current = (Math.imul(current, MULTIPLIER) + INCREMENT) >>> 0;
+    return current;
+  };
+}
+
+/**
+ * Splits a bigint into eight 32-bit limbs, least-significant limb first.
+ * Bits at or above 2^256 are dropped.
+ *
+ * @param v Value to pack.
+ * @returns A fresh 8-element Uint32Array holding the little-endian limbs.
+ */
+function bigintToPackedU32x8(v: bigint): Uint32Array {
+  const limbs = new Uint32Array(8);
+  for (let limb = 0; limb < limbs.length; limb++) {
+    // asUintN(32, ...) keeps exactly the low 32 bits of the shifted value.
+    limbs[limb] = Number(BigInt.asUintN(32, v >> BigInt(32 * limb)));
+  }
+  return limbs;
+}
+
+/**
+ * Reassembles eight little-endian 32-bit limbs into a bigint.
+ *
+ * @param w Word array containing the limbs.
+ * @param off Index of the least-significant limb inside `w`.
+ * @returns The unsigned value of the eight limbs starting at `off`.
+ */
+function packedU32x8ToBigint(w: Uint32Array, off: number): bigint {
+  let acc = 0n;
+  for (let limb = 0; limb < 8; limb++) {
+    // Each limb lands at its own 32-bit position, so adding equals OR-ing.
+    acc += BigInt(w[off + limb] >>> 0) << BigInt(32 * limb);
+  }
+  return acc;
+}
+
+/** One bucket comparison between the GPU output and the CPU reference. */
+interface OracleBucketCheck {
+  subtask: number;
+  bucket: number;
+  count: number;
+  gpu_x: string;
+  gpu_y: string;
+  ref_x: string;
+  ref_y: string;
+  ok: boolean;
+}
+
+/** Result record of one oracle run: the parameters used, the stats returned by
+ * the orchestrator, and the outcome of the per-bucket comparison. */
+interface OracleResult {
+  num_subtasks: number;
+  num_columns: number;
+  input_size: number;
+  s: number;
+  wgi: number;
+  tpb: number;
+  per_thread: number;
+  max_levels: number;
+  total_passes: number;
+  gpu_wall_ms: number;
+  buckets_checked: number;
+  buckets_passed: number;
+  first_mismatches: OracleBucketCheck[];
+  all_passed: boolean;
+}
+
+/**
+ * Mutable state of this bench page. It is published on `window.__bench`
+ * (presumably so an external harness can poll it — confirm against the runner)
+ * and is uploaded in full by postFinal().
+ */
+interface BenchState {
+  state: 'boot' | 'running' | 'done' | 'error';
+  // Parsed query parameters; `Record` needs both type arguments to compile.
+  params: Record<string, unknown> | null;
+  results: OracleResult[];
+  error: string | null;
+  log: string[];
+}
+
+const benchState: BenchState = { state: 'boot', params: null, results: [], error: null, log: [] };
+(window as unknown as { __bench: BenchState }).__bench = benchState;
+const resultsClient = makeResultsClient({ page: 'bench-msm-oracle-prod' });
+// The run id is published next to the state under `window.__runId`.
+(window as unknown as { __runId: string }).__runId = resultsClient.runId;
+
+/**
+ * Uploads the complete bench state (status, params, results, error, log) plus
+ * basic browser information through the results client.
+ *
+ * @returns A promise that settles when resultsClient.postResults settles.
+ */
+async function postFinal(): Promise<void> {
+  await resultsClient.postResults({
+    state: benchState.state,
+    params: benchState.params,
+    results: benchState.results,
+    error: benchState.error,
+    log: benchState.log,
+    userAgent: navigator.userAgent,
+    hardwareConcurrency: navigator.hardwareConcurrency,
+  });
+}
+
+// May be null when the page has no #log element; log() tolerates that.
+const $log = document.getElementById('log');
+
+/**
+ * Records a message in benchState.log, echoes it to the console and, when the
+ * #log element exists, appends it to the page.
+ *
+ * The in-memory and console records are written first and the DOM write is
+ * guarded, so a missing #log element can no longer make log() throw inside the
+ * error handlers and hide the failure they were reporting.
+ *
+ * @param level Severity; doubles as the CSS class for everything except 'info'.
+ * @param msg Message text (inserted as textContent, never as HTML).
+ */
+function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) {
+  benchState.log.push(`[${level}] ${msg}`);
+  console.log(`[bench-msm-oracle-prod] ${msg}`);
+  if ($log === null) return;
+  const row = document.createElement('div');
+  row.className = level === 'info' ? '' : level;
+  row.textContent = msg;
+  $log.appendChild(row);
+}
+
+/**
+ * Copies the first `bytes` bytes of a GPU buffer into CPU memory.
+ *
+ * @param device Device that owns `buf`.
+ * @param buf Source buffer; must have been created with COPY_SRC usage.
+ * @param bytes Number of bytes to read, starting at offset 0.
+ * @returns A Uint32Array backed by a private copy of the data.
+ */
+async function readbackU32(device: GPUDevice, buf: GPUBuffer, bytes: number): Promise<Uint32Array> {
+  const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST });
+  try {
+    const enc = device.createCommandEncoder();
+    enc.copyBufferToBuffer(buf, 0, staging, 0, bytes);
+    device.queue.submit([enc.finish()]);
+    await device.queue.onSubmittedWorkDone();
+    await staging.mapAsync(GPUMapMode.READ);
+    // slice(0) copies the mapped range so the data outlives unmap()/destroy().
+    const out = new Uint32Array(staging.getMappedRange().slice(0));
+    staging.unmap();
+    return out;
+  } finally {
+    // Release the staging buffer even when the copy or the mapping rejects.
+    staging.destroy();
+  }
+}
+
+/**
+ * Runs the prod v2 pair-tree orchestrator once on synthetic input and compares
+ * every bucket against a CPU reference computed with @noble/curves.
+ *
+ * Steps: generate INPUT_SIZE points k*G and encode them as (v * R) mod p; build
+ * one CSR (val_idx / row_ptr) per subtask from a random bucket assignment;
+ * upload and call runSmvpV2PairTree; read back running_x / running_y /
+ * bucket_active; decode with Rinv and compare bucket by bucket.
+ *
+ * All GPU buffers created here are destroyed before returning, also when the
+ * run throws.
+ *
+ * @param device WebGPU device used for buffers and readbacks.
+ * @param sm Shader manager handed to runSmvpV2PairTree.
+ * @param R Montgomery factor applied to coordinates before upload.
+ * @param Rinv Inverse of R modulo p, used to decode the GPU output.
+ * @param p Base-field modulus.
+ * @returns Run parameters, orchestrator stats and the comparison summary.
+ */
+async function runOracle(device: GPUDevice, sm: ShaderManager, R: bigint, Rinv: bigint, p: bigint): Promise<OracleResult> {
+  log('info', `=== T=${NUM_SUBTASKS} B=${NUM_COLUMNS} N=${INPUT_SIZE} S=${S} WGI=${WGI} TPB=${TPB} PER=${PER_THREAD} MAX_LEVELS=${MAX_LEVELS}`);
+  const rng = makeRng(SEED);
+  const G1 = bn254.G1.ProjectivePoint;
+  const order = bn254.fields.Fr.ORDER;
+
+  // Input points: a 256-bit scalar from eight RNG words, reduced mod the group
+  // order (0 is mapped to 1), multiplied onto the generator.
+  const points: Array<{ x: bigint; y: bigint }> = [];
+  const pointXWords = new Uint32Array(INPUT_SIZE * 8);
+  const pointYWords = new Uint32Array(INPUT_SIZE * 8);
+  for (let i = 0; i < INPUT_SIZE; i++) {
+    let k = 0n;
+    for (let w = 0; w < 8; w++) k = (k << 32n) | BigInt(rng() >>> 0);
+    k = k % order;
+    if (k === 0n) k = 1n;
+    const aff = G1.BASE.multiply(k).toAffine();
+    points.push({ x: aff.x, y: aff.y });
+    pointXWords.set(bigintToPackedU32x8((aff.x * R) % p), 8 * i);
+    pointYWords.set(bigintToPackedU32x8((aff.y * R) % p), 8 * i);
+  }
+  log('info', `generated ${INPUT_SIZE} BN254 affine points`);
+
+  // Synthetic CSR: every point gets a random bucket per subtask; val_idx lists
+  // point indices grouped by bucket, row_ptr holds the per-bucket offsets.
+  const valIdxArr = new Uint32Array(NUM_SUBTASKS * INPUT_SIZE);
+  const rowPtrArr = new Uint32Array(NUM_SUBTASKS * (NUM_COLUMNS + 1));
+  const bucketOf: Uint32Array[] = [];
+  for (let st = 0; st < NUM_SUBTASKS; st++) {
+    const bucket = new Uint32Array(INPUT_SIZE);
+    const counts = new Uint32Array(NUM_COLUMNS);
+    for (let i = 0; i < INPUT_SIZE; i++) {
+      // Combine the top 16 bits of two RNG outputs into one 32-bit value.
+      const hi = (rng() >>> 16) & 0xffff;
+      const lo = (rng() >>> 16) & 0xffff;
+      const v = hi * 0x10000 + lo;
+      const b = v % NUM_COLUMNS;
+      bucket[i] = b;
+      counts[b]++;
+    }
+    bucketOf.push(bucket);
+    const offsets = new Uint32Array(NUM_COLUMNS + 1);
+    for (let b = 0; b < NUM_COLUMNS; b++) offsets[b + 1] = offsets[b] + counts[b];
+    const cursor = new Uint32Array(NUM_COLUMNS);
+    for (let i = 0; i < INPUT_SIZE; i++) {
+      const b = bucket[i];
+      const slot = offsets[b] + cursor[b]++;
+      valIdxArr[st * INPUT_SIZE + slot] = i;
+    }
+    const rpBase = st * (NUM_COLUMNS + 1);
+    for (let b = 0; b <= NUM_COLUMNS; b++) rowPtrArr[rpBase + b] = offsets[b];
+  }
+  log('info', `built synthetic CSR for ${NUM_SUBTASKS} window(s)`);
+
+  // Every buffer made through mk() is tracked so the finally block can release
+  // it no matter where the run fails.
+  const owned: GPUBuffer[] = [];
+  const mk = (bytes: number, extra: GPUBufferUsageFlags = 0): GPUBuffer => {
+    const buf = device.createBuffer({ size: bytes, usage: GPUBufferUsage.STORAGE | extra });
+    owned.push(buf);
+    return buf;
+  };
+
+  try {
+    const val_idx_buf = mk(valIdxArr.byteLength, GPUBufferUsage.COPY_DST);
+    const row_ptr_buf = mk(rowPtrArr.byteLength, GPUBufferUsage.COPY_DST);
+    const point_x_buf = mk(INPUT_SIZE * 32, GPUBufferUsage.COPY_DST);
+    const point_y_buf = mk(INPUT_SIZE * 32, GPUBufferUsage.COPY_DST);
+    const running_x_buf = mk(NUM_SUBTASKS * NUM_COLUMNS * 32, GPUBufferUsage.COPY_SRC);
+    const running_y_buf = mk(NUM_SUBTASKS * NUM_COLUMNS * 32, GPUBufferUsage.COPY_SRC);
+    const bucket_active_buf = mk(NUM_SUBTASKS * NUM_COLUMNS * 4, GPUBufferUsage.COPY_SRC);
+
+    device.queue.writeBuffer(val_idx_buf, 0, valIdxArr as BufferSource);
+    device.queue.writeBuffer(row_ptr_buf, 0, rowPtrArr as BufferSource);
+    device.queue.writeBuffer(point_x_buf, 0, pointXWords as BufferSource);
+    device.queue.writeBuffer(point_y_buf, 0, pointYWords as BufferSource);
+
+    const stats = await runSmvpV2PairTree({
+      device,
+      shaderManager: sm,
+      num_subtasks: NUM_SUBTASKS,
+      num_columns: NUM_COLUMNS,
+      input_size: INPUT_SIZE,
+      s: S,
+      tpb: TPB,
+      per_thread: PER_THREAD,
+      wgi: WGI,
+      max_levels: MAX_LEVELS,
+      val_idx_buf, row_ptr_buf, point_x_buf, point_y_buf,
+      running_x_buf, running_y_buf, bucket_active_buf,
+    });
+    log('info', `runSmvpV2PairTree returned: ${JSON.stringify(stats)}`);
+
+    const runningXWords = await readbackU32(device, running_x_buf, NUM_SUBTASKS * NUM_COLUMNS * 32);
+    const runningYWords = await readbackU32(device, running_y_buf, NUM_SUBTASKS * NUM_COLUMNS * 32);
+    const bucketActive = await readbackU32(device, bucket_active_buf, NUM_SUBTASKS * NUM_COLUMNS * 4);
+
+    // CPU reference: one pass over the points per subtask, accumulating each
+    // point into its bucket in ascending index order. `null` marks an empty
+    // bucket or a sum that is the point at infinity.
+    const refSumPerBucket = new Map<number, { x: bigint; y: bigint } | null>();
+    const countsPerSubtask: Uint32Array[] = [];
+    for (let st = 0; st < NUM_SUBTASKS; st++) {
+      const bucket = bucketOf[st];
+      const counts = new Uint32Array(NUM_COLUMNS);
+      const sums = Array.from({ length: NUM_COLUMNS }, () => G1.ZERO);
+      for (let i = 0; i < INPUT_SIZE; i++) {
+        const b = bucket[i];
+        sums[b] = sums[b].add(G1.fromAffine({ x: points[i].x, y: points[i].y }));
+        counts[b]++;
+      }
+      countsPerSubtask.push(counts);
+      for (let b = 0; b < NUM_COLUMNS; b++) {
+        const bucket_global = st * NUM_COLUMNS + b;
+        refSumPerBucket.set(bucket_global, counts[b] === 0 ? null : (sums[b].is0() ? null : sums[b].toAffine()));
+      }
+    }
+
+    // `checks` holds buckets whose coordinates were compared; `mismatches`
+    // holds every activity-flag mismatch plus the first few coordinate
+    // mismatches; `failCount` is the true number of failing buckets.
+    const checks: OracleResult['first_mismatches'] = [];
+    const mismatches: OracleResult['first_mismatches'] = [];
+    let passCount = 0;
+    let failCount = 0;
+    for (let st = 0; st < NUM_SUBTASKS; st++) {
+      const counts = countsPerSubtask[st];
+      for (let b = 0; b < NUM_COLUMNS; b++) {
+        const bucket_global = st * NUM_COLUMNS + b;
+        const ref = refSumPerBucket.get(bucket_global) ?? null;
+        const active = bucketActive[bucket_global];
+        if (counts[b] === 0) {
+          // An empty bucket must be reported inactive by the GPU.
+          if (active !== 0) {
+            failCount++;
+            mismatches.push({ subtask: st, bucket: b, count: 0, gpu_x: 'active=1', gpu_y: '', ref_x: 'empty', ref_y: '', ok: false });
+          }
+          continue;
+        }
+        if (active === 0) {
+          // A populated bucket that the GPU reports inactive.
+          failCount++;
+          mismatches.push({ subtask: st, bucket: b, count: counts[b], gpu_x: 'active=0', gpu_y: '', ref_x: ref ? ref.x.toString(16) : 'INF', ref_y: ref ? ref.y.toString(16) : 'INF', ok: false });
+          continue;
+        }
+        const xMont = packedU32x8ToBigint(runningXWords, bucket_global * 8);
+        const yMont = packedU32x8ToBigint(runningYWords, bucket_global * 8);
+        const gx = (xMont * Rinv) % p;
+        const gy = (yMont * Rinv) % p;
+        const ok = ref !== null && gx === ref.x && gy === ref.y;
+        const entry = { subtask: st, bucket: b, count: counts[b], gpu_x: gx.toString(16), gpu_y: gy.toString(16), ref_x: ref ? ref.x.toString(16) : 'INF', ref_y: ref ? ref.y.toString(16) : 'INF', ok };
+        checks.push(entry);
+        if (ok) {
+          passCount++;
+        } else {
+          failCount++;
+          if (mismatches.length < 8) mismatches.push(entry);
+        }
+      }
+    }
+    const allPassed = failCount === 0;
+
+    if (allPassed) {
+      log('ok', `oracle PASS — ${passCount}/${checks.length} buckets match noble reference (prod orchestrator)`);
+    } else {
+      // Report the real failure count, not the length of the capped list.
+      log('err', `oracle FAIL — ${failCount} bucket(s) mismatched; ${passCount}/${checks.length} coordinate checks passed (showing first ${Math.min(mismatches.length, 8)})`);
+      for (const m of mismatches.slice(0, 8)) {
+        log('err', ` subtask=${m.subtask} bucket=${m.bucket} count=${m.count}`);
+        log('err', ` gpu: x=${m.gpu_x} y=${m.gpu_y}`);
+        log('err', ` ref: x=${m.ref_x} y=${m.ref_y}`);
+      }
+    }
+
+    return {
+      num_subtasks: NUM_SUBTASKS,
+      num_columns: NUM_COLUMNS,
+      input_size: INPUT_SIZE,
+      s: S,
+      wgi: WGI,
+      tpb: TPB,
+      per_thread: PER_THREAD,
+      max_levels: MAX_LEVELS,
+      total_passes: stats.total_passes,
+      gpu_wall_ms: stats.gpu_wall_ms,
+      buckets_checked: checks.length,
+      buckets_passed: passCount,
+      first_mismatches: mismatches,
+      all_passed: allPassed,
+    };
+  } finally {
+    for (const buf of owned) buf.destroy();
+  }
+}
+
+/**
+ * Applies URL query parameters to the module-level run configuration.
+ *
+ * A parameter that is absent or empty keeps its default. Values are read as
+ * decimal integers, or as hexadecimal when they start with `0x` (so
+ * `?seed=0xc0de` works; a radix-10 parse would silently yield 0).
+ *
+ * @returns The effective configuration after applying the overrides.
+ * @throws Error when a supplied value is not an integer or is below the
+ *   minimum for that parameter (1 for every size, unbounded for the seed).
+ */
+function parseParams() {
+  const qp = new URLSearchParams(window.location.search);
+  const readInt = (key: string, fallback: number, min: number): number => {
+    const raw = qp.get(key);
+    if (!raw) return fallback;
+    const value = /^\s*[-+]?0x/i.test(raw) ? parseInt(raw, 16) : parseInt(raw, 10);
+    if (!Number.isSafeInteger(value) || value < min) {
+      throw new Error(`invalid ?${key}=${raw}: expected an integer >= ${min}`);
+    }
+    return value;
+  };
+  NUM_SUBTASKS = readInt('subtasks', NUM_SUBTASKS, 1);
+  NUM_COLUMNS = readInt('columns', NUM_COLUMNS, 1);
+  INPUT_SIZE = readInt('input', INPUT_SIZE, 1);
+  S = readInt('s', S, 1);
+  WGI = readInt('wgi', WGI, 1);
+  TPB = readInt('tpb', TPB, 1);
+  PER_THREAD = readInt('per', PER_THREAD, 1);
+  MAX_LEVELS = readInt('lvls', MAX_LEVELS, 1);
+  // Any integer is a valid seed; makeRng truncates it to 32 bits.
+  SEED = readInt('seed', SEED, Number.MIN_SAFE_INTEGER);
+  return { subtasks: NUM_SUBTASKS, columns: NUM_COLUMNS, input: INPUT_SIZE, s: S, wgi: WGI, tpb: TPB, per_thread: PER_THREAD, max_levels: MAX_LEVELS, seed: SEED };
+}
+
+/**
+ * Page driver: checks for WebGPU, parses the query parameters, acquires a
+ * device, runs the oracle once and publishes the outcome through benchState
+ * and the results client. Any failure is logged and stored in benchState
+ * instead of being rethrown.
+ */
+async function main() {
+  try {
+    if (!('gpu' in navigator)) {
+      throw new Error('navigator.gpu missing');
+    }
+    const params = parseParams();
+    benchState.params = params;
+    log('info', `params: ${JSON.stringify(params)}`);
+    benchState.state = 'running';
+
+    const device = await get_device();
+    log('info', 'WebGPU device acquired');
+
+    // Montgomery factor for the base field and its inverse for decoding.
+    const modulus = BN254_BASE_FIELD;
+    const { r: montR } = compute_misc_params(modulus, 13);
+    const montRInv = modInverse(montR, modulus);
+    const shaders = new ShaderManager(NUM_SUBTASKS, NUM_COLUMNS, BN254_CURVE_CONFIG, false);
+
+    const outcome = await runOracle(device, shaders, montR, montRInv, modulus);
+    benchState.results.push(outcome);
+    const { all_passed, buckets_passed, buckets_checked, gpu_wall_ms } = outcome;
+    resultsClient.postProgress({ kind: 'oracle_prod_done', all_passed, buckets_passed, buckets_checked, gpu_wall_ms });
+
+    benchState.state = 'done';
+    log('ok', 'done');
+  } catch (e) {
+    let msg: string;
+    if (e instanceof Error) {
+      msg = `${e.message}\n${e.stack}`;
+    } else {
+      msg = String(e);
+    }
+    log('err', `FATAL: ${msg}`);
+    benchState.state = 'error';
+    benchState.error = msg;
+  }
+}
+
+// Entry point. main() handles its own errors; the catch here only covers a
+// rejection that still escapes, and the upload is attempted in every case.
+void (async () => {
+  try {
+    await main();
+  } catch (e) {
+    const msg = e instanceof Error ? e.message : String(e);
+    log('err', `unhandled: ${msg}`);
+    benchState.state = 'error';
+    benchState.error = msg;
+  } finally {
+    // Upload failures are deliberately ignored (best effort).
+    postFinal().catch(() => {});
+  }
+})();
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.html b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.html
new file mode 100644
index 000000000000..8296a4e0844b
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.html
@@ -0,0 +1,22 @@
+
+
+
+
+ v2 pair-tree bucket-accumulate noble-CPU oracle (WebGPU)
+
+
+
+ v2 pair-tree bucket-accumulate noble-CPU oracle
+ Query params: ?n=N&buckets=B&s=S&wgi=W&seed=K
+
+
+
+
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.ts
new file mode 100644
index 000000000000..2addcec51fae
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.ts
@@ -0,0 +1,625 @@
+///
+// End-to-end correctness oracle for the v2 bin-packed pair-tree
+// bucket-accumulate pipeline. Feeds REAL BN254 affine points (random
+// scalar * G, via @noble/curves) into the pair-tree and verifies that
+// the per-bucket reduced sum matches a noble-projective reference.
+//
+// This is the test fused_revcarry never had: a ground-truth oracle on
+// real curve data. If this passes, the v2 pair-tree's bucket-accumulate
+// math (disjoint pair-sum + suffix-product single-fr_inv + lean affine
+// add) is correct end-to-end on real BN254 points.
+//
+// Scope: validates the round kernel + planner only. BPR / horner /
+// finalize are NOT part of v2 yet — they're step 3+ of the rewrite plan.
+// This oracle stops at "per-bucket sum is correct".
+//
+// Sizing: tiny by design — N=256 points, B=32 buckets, single window
+// (no signed slicing). Each bucket gets ~8 points; the pair-tree
+// reduces in ~4 levels.
+
+import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
+import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
+import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
+import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js';
+import { BN254_BASE_FIELD, modInverse } from '../../src/msm_webgpu/cuzk/bn254.js';
+import { makeResultsClient } from './results_post.js';
+import { bn254 } from '@noble/curves/bn254';
+
+// Number of 4-word groups per 8-word field element in the SoA buffers (see the
+// sizing in makeSoABuf and the per-group loops in writeElem / decodeAt).
+const PG = 2;
+// Run configuration defaults; parseParams() overwrites each one whose URL
+// query parameter is present.
+let NPTS = 256; // ?n: number of generated curve points
+let BUCKETS = 32; // ?buckets: number of buckets the points are spread over
+let S = 16; // ?s: pair slots per chunk in buildLevelPlan and the generated shaders
+let WGI = 64; // ?wgi: divisor used to turn chunk/carry counts into workgroup counts
+let SEED = 0xa110ce; // ?seed: seeds makeRng() for scalars, pad values and bucket assignment
+
+/**
+ * Creates a deterministic uint32 generator (linear congruential recurrence).
+ *
+ * @param seed Starting state, truncated to 32 unsigned bits; 0 becomes 1.
+ * @returns A closure that advances the state and returns the new value.
+ */
+function makeRng(seed: number): () => number {
+  const truncated = seed >>> 0;
+  const lcg = { value: truncated !== 0 ? truncated : 1 };
+  return () => {
+    const product = Math.imul(lcg.value, 1664525);
+    lcg.value = (product + 1013904223) >>> 0;
+    return lcg.value;
+  };
+}
+
+/**
+ * Packs the low 256 bits of a bigint into eight uint32 limbs, little-endian.
+ *
+ * @param v Value to pack; bits at or above 2^256 are discarded.
+ * @returns A new Uint32Array of length 8 with limb 0 the least significant.
+ */
+function bigintToPackedU32x8(v: bigint): Uint32Array {
+  const packed = new Uint32Array(8);
+  let shift = 0n;
+  for (let limb = 0; limb < 8; limb++, shift += 32n) {
+    packed[limb] = Number(BigInt.asUintN(32, v >> shift));
+  }
+  return packed;
+}
+
+/**
+ * Reads eight little-endian uint32 limbs from `w` and returns their value.
+ *
+ * @param w Word array holding the limbs.
+ * @param off Index of the least-significant limb.
+ * @returns The unsigned 256-bit value as a bigint.
+ */
+function packedU32x8ToBigint(w: Uint32Array, off: number): bigint {
+  let value = 0n;
+  let weight = 0n;
+  for (let limb = 0; limb < 8; limb++, weight += 32n) {
+    value |= BigInt(w[off + limb] >>> 0) << weight;
+  }
+  return value;
+}
+
+/**
+ * Allocates a storage buffer for M elements in the two-plane SoA layout.
+ *
+ * @param device Device to allocate on.
+ * @param M Number of elements per plane.
+ * @param copyDst Adds COPY_DST usage when true.
+ * @param copySrc Adds COPY_SRC usage when true.
+ * @returns The new buffer of 2 * PG * M * 16 bytes.
+ */
+function makeSoABuf(device: GPUDevice, M: number, copyDst: boolean, copySrc: boolean): GPUBuffer {
+  const usage =
+    GPUBufferUsage.STORAGE |
+    (copyDst ? GPUBufferUsage.COPY_DST : 0) |
+    (copySrc ? GPUBufferUsage.COPY_SRC : 0);
+  // 2 planes x PG groups x M elements x 4 words x 4 bytes per word.
+  const size = 2 * PG * M * 4 * 4;
+  return device.createBuffer({ size, usage });
+}
+
+/** Affine point coordinates as plain bigints, as stored by
+ * buildL0WithRealPoints before Montgomery encoding. */
+interface CurvePoint {
+ x: bigint;
+ y: bigint;
+}
+
+/**
+ * Builds the level-0 input: N points k*G assigned to B random buckets, written
+ * bucket-sorted into a two-plane SoA word buffer (plane 0 = x, plane 1 = y),
+ * followed by two pad elements at indices N and N + 1.
+ *
+ * The order of rng() draws (scalars, then pad values, then bucket choices)
+ * determines the generated data, so statements must not be reordered.
+ *
+ * @param N Number of curve points.
+ * @param B Number of buckets.
+ * @param R Montgomery factor; coordinates are stored as (v * R) mod p.
+ * @param p Base-field modulus.
+ * @param rng Source of uint32 values (see makeRng).
+ * @returns The packed buffer, per-bucket counts and offsets, M = N + 2, the
+ *   generated points, and each point's bucket.
+ */
+function buildL0WithRealPoints(
+ N: number,
+ B: number,
+ R: bigint,
+ p: bigint,
+ rng: () => number,
+): {
+ initBuf: Uint32Array;
+ initCounts: Uint32Array;
+ initOffsets: Uint32Array;
+ M: number;
+ points: CurvePoint[];
+ bucket: Uint32Array;
+} {
+ // Two extra element slots after the N points hold the pad operands.
+ const M = N + 2;
+ const buf = new Uint32Array(2 * PG * M * 4);
+ const G1 = bn254.G1.ProjectivePoint;
+ const order = bn254.fields.Fr.ORDER;
+
+ const points: CurvePoint[] = [];
+ // Staging arrays in generation order: 8 words per element.
+ const xWords = new Uint32Array(8 * M);
+ const yWords = new Uint32Array(8 * M);
+
+ for (let i = 0; i < N; i++) {
+ // 256-bit scalar from eight RNG words, reduced mod the group order; 0 -> 1.
+ let k = 0n;
+ for (let w = 0; w < 8; w++) k = (k << 32n) | BigInt(rng() >>> 0);
+ k = k % order;
+ if (k === 0n) k = 1n;
+ const aff = G1.BASE.multiply(k).toAffine();
+ points.push({ x: aff.x, y: aff.y });
+ const xMont = (aff.x * R) % p;
+ const yMont = (aff.y * R) % p;
+ xWords.set(bigintToPackedU32x8(xMont), 8 * i);
+ yWords.set(bigintToPackedU32x8(yMont), 8 * i);
+ }
+
+ // Pad elements: a random non-zero x and a y computed from x by a fixed
+ // formula (not from the curve equation).
+ for (let pad = 0; pad < 2; pad++) {
+ const i = N + pad;
+ let xCand: bigint;
+ do {
+ xCand = 0n;
+ for (let w = 0; w < 8; w++) xCand = (xCand << 32n) | BigInt(rng() >>> 0);
+ xCand = xCand % p;
+ } while (xCand === 0n);
+ const yCand = ((xCand + 1n + BigInt(pad)) * 7n) % p;
+ const xMont = (xCand * R) % p;
+ const yMont = (yCand * R) % p;
+ xWords.set(bigintToPackedU32x8(xMont), 8 * i);
+ yWords.set(bigintToPackedU32x8(yMont), 8 * i);
+ }
+
+ // Random bucket per point: top 16 bits of two RNG outputs combined, mod B.
+ const bucket = new Uint32Array(N);
+ const counts = new Uint32Array(B);
+ for (let i = 0; i < N; i++) {
+ const hi = (rng() >>> 16) & 0xffff;
+ const lo = (rng() >>> 16) & 0xffff;
+ const v = hi * 0x10000 + lo;
+ const b = v % B;
+ bucket[i] = b;
+ counts[b]++;
+ }
+ // Exclusive prefix sum: offsets[b] is the first element index of bucket b.
+ const offsets = new Uint32Array(B + 1);
+ for (let b = 0; b < B; b++) offsets[b + 1] = offsets[b] + counts[b];
+
+ const cursor = new Uint32Array(B);
+ // Copies one 8-word element from `words` into plane `planeIdx` at element
+ // index `dstIdx`, as PG consecutive groups of 4 words.
+ const writeElem = (planeIdx: number, dstIdx: number, words: Uint32Array, srcOff: number) => {
+ const planeBase = planeIdx * PG * M;
+ for (let v = 0; v < PG; v++) {
+ const base = (planeBase + PG * dstIdx + v) * 4;
+ buf[base + 0] = words[srcOff + 4 * v + 0];
+ buf[base + 1] = words[srcOff + 4 * v + 1];
+ buf[base + 2] = words[srcOff + 4 * v + 2];
+ buf[base + 3] = words[srcOff + 4 * v + 3];
+ }
+ };
+ // Scatter the points so that each bucket occupies a contiguous range.
+ for (let i = 0; i < N; i++) {
+ const b = bucket[i];
+ const dst = offsets[b] + cursor[b]++;
+ writeElem(0, dst, xWords, 8 * i);
+ writeElem(1, dst, yWords, 8 * i);
+ }
+ // The two pad elements keep their positions M - 2 and M - 1.
+ writeElem(0, M - 2, xWords, 8 * (M - 2));
+ writeElem(1, M - 2, yWords, 8 * (M - 2));
+ writeElem(0, M - 1, xWords, 8 * (M - 1));
+ writeElem(1, M - 1, yWords, 8 * (M - 1));
+
+ return { initBuf: buf, initCounts: counts, initOffsets: offsets, M, points, bucket };
+}
+
+/**
+ * Plans one pair-tree level. Each bucket of n elements contributes
+ * floor(n / 2) pairs and, when n is odd, one carried element.
+ *
+ * @param counts Elements per bucket at the current level.
+ * @param offsets Start index of each bucket at the current level.
+ * @param s Pair slots per chunk.
+ * @param padLIdx Left operand index used for unused pair slots.
+ * @param padRIdx Right operand index used for unused pair slots.
+ * @param discardIdx Scatter destination used for unused pair slots.
+ * @returns chunkPlan (two operand indices per slot), scatterPlan (destination
+ *   per slot), carryPlan ((source, destination) per carry), the next level's
+ *   counts and offsets, and the chunk / carry / pair totals.
+ */
+function buildLevelPlan(
+  counts: Uint32Array,
+  offsets: Uint32Array,
+  s: number,
+  padLIdx: number,
+  padRIdx: number,
+  discardIdx: number,
+) {
+  const bucketCount = counts.length;
+
+  // Pass 1: next-level sizes, their prefix sums and the level totals.
+  const newCounts = new Uint32Array(bucketCount);
+  const newOffsets = new Uint32Array(bucketCount + 1);
+  let totalPairs = 0;
+  let numCarries = 0;
+  counts.forEach((n, b) => {
+    const pairs = n >>> 1;
+    const carry = n % 2;
+    totalPairs += pairs;
+    numCarries += carry;
+    newCounts[b] = pairs + carry;
+    newOffsets[b + 1] = newOffsets[b] + newCounts[b];
+  });
+
+  // At least one chunk is planned, and at least one carry entry is allocated.
+  const numChunks = Math.max(1, Math.ceil(totalPairs / s));
+  const slotCount = s * numChunks;
+  const chunkPlan = new Uint32Array(2 * slotCount);
+  const scatterPlan = new Uint32Array(slotCount).fill(discardIdx);
+  const carryPlan = new Uint32Array(2 * Math.max(1, numCarries));
+
+  // Every slot starts as a pad pair; real pairs overwrite the leading slots.
+  for (let k = 0; k < chunkPlan.length; k += 2) {
+    chunkPlan[k] = padLIdx;
+    chunkPlan[k + 1] = padRIdx;
+  }
+
+  // Pass 2: real pairs in bucket order, then that bucket's carry if n is odd.
+  let nextSlot = 0;
+  let nextCarry = 0;
+  counts.forEach((n, b) => {
+    const pairs = n >>> 1;
+    const srcBase = offsets[b];
+    const dstBase = newOffsets[b];
+    for (let j = 0; j < pairs; j++, nextSlot++) {
+      chunkPlan[2 * nextSlot] = srcBase + 2 * j;
+      chunkPlan[2 * nextSlot + 1] = srcBase + 2 * j + 1;
+      scatterPlan[nextSlot] = dstBase + j;
+    }
+    if (n % 2 === 1) {
+      // The last element moves to the slot right after the pair results.
+      carryPlan[2 * nextCarry] = srcBase + n - 1;
+      carryPlan[2 * nextCarry + 1] = dstBase + pairs;
+      nextCarry++;
+    }
+  });
+
+  return { chunkPlan, scatterPlan, carryPlan, newCounts, newOffsets, numChunks, numCarries, totalPairs };
+}
+
+/** Comparison record for one non-empty bucket, as built in runOracle. */
+interface BucketCheck {
+ bucket: number; // bucket index
+ count: number; // number of input points assigned to the bucket
+ gpu_x: string; // decoded GPU x coordinate, hexadecimal
+ gpu_y: string; // decoded GPU y coordinate, hexadecimal
+ ref_x: string; // reference x in hexadecimal, or 'INF' when there is no affine reference
+ ref_y: string; // reference y in hexadecimal, or 'INF' when there is no affine reference
+ ok: boolean; // true when the GPU coordinates matched the reference
+}
+
+/** Summary of one oracle run, as returned by runOracle and stored in
+ * benchState.results. */
+interface OracleResult {
+ n: number; // NPTS used for the run
+ buckets: number; // BUCKETS used for the run
+ s: number; // S used for the run
+ wgi: number; // WGI used for the run
+ levels: number; // number of pair-tree levels that were planned
+ total_pair_adds: number; // sum of totalPairs over all levels
+ buckets_checked: number; // non-empty buckets compared against the reference
+ buckets_passed: number; // buckets whose coordinates matched
+ first_mismatches: BucketCheck[]; // at most the first 8 failing buckets
+ all_passed: boolean; // true when no bucket failed
+ gpu_wall_ms: number; // wall time from submit until onSubmittedWorkDone resolved
+}
+
+/**
+ * Mutable state of this bench page. It is published on `window.__bench`
+ * (presumably so an external harness can poll it — confirm against the runner)
+ * and is uploaded in full by postFinal().
+ */
+interface BenchState {
+  state: 'boot' | 'running' | 'done' | 'error';
+  // Parsed query parameters; `Record` needs both type arguments to compile.
+  params: Record<string, unknown> | null;
+  results: OracleResult[];
+  error: string | null;
+  log: string[];
+}
+
+const benchState: BenchState = { state: 'boot', params: null, results: [], error: null, log: [] };
+(window as unknown as { __bench: BenchState }).__bench = benchState;
+const resultsClient = makeResultsClient({ page: 'bench-msm-oracle' });
+// The run id is published next to the state under `window.__runId`.
+(window as unknown as { __runId: string }).__runId = resultsClient.runId;
+
+/**
+ * Uploads the complete bench state (status, params, results, error, log) plus
+ * basic browser information through the results client.
+ *
+ * @returns A promise that settles when resultsClient.postResults settles.
+ */
+async function postFinal(): Promise<void> {
+  await resultsClient.postResults({
+    state: benchState.state,
+    params: benchState.params,
+    results: benchState.results,
+    error: benchState.error,
+    log: benchState.log,
+    userAgent: navigator.userAgent,
+    hardwareConcurrency: navigator.hardwareConcurrency,
+  });
+}
+
+// May be null when the page has no #log element; log() tolerates that.
+const $log = document.getElementById('log');
+
+/**
+ * Records a message in benchState.log, echoes it to the console and, when the
+ * #log element exists, appends it to the page.
+ *
+ * The in-memory and console records are written first and the DOM write is
+ * guarded, so a missing #log element can no longer make log() throw inside the
+ * error handlers and hide the failure they were reporting.
+ *
+ * @param level Severity; doubles as the CSS class for everything except 'info'.
+ * @param msg Message text (inserted as textContent, never as HTML).
+ */
+function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) {
+  benchState.log.push(`[${level}] ${msg}`);
+  console.log(`[bench-msm-oracle] ${msg}`);
+  if ($log === null) return;
+  const row = document.createElement('div');
+  row.className = level === 'info' ? '' : level;
+  row.textContent = msg;
+  $log.appendChild(row);
+}
+
+/**
+ * Compiles one WGSL module and builds a compute pipeline (entry point `main`)
+ * with a single bind group layout.
+ *
+ * Compilation messages are inspected first: errors are logged to the page and
+ * the console and abort the build, everything else is only warned about on
+ * the console.
+ *
+ * @param device Device to compile on.
+ * @param code WGSL source.
+ * @param key Label used in diagnostics.
+ * @param layout Bind group layout for group 0.
+ * @returns The compiled compute pipeline.
+ * @throws Error when the module has at least one compile error; the message
+ *   carries up to the first four error lines.
+ */
+async function compileOne(device: GPUDevice, code: string, key: string, layout: GPUBindGroupLayout): Promise<GPUComputePipeline> {
+  const module = device.createShaderModule({ code });
+  const info = await module.getCompilationInfo();
+  const errLines: string[] = [];
+  for (const m of info.messages) {
+    const line = `[shader ${key}] ${m.type}: ${m.message} (line ${m.lineNum}, col ${m.linePos})`;
+    if (m.type === 'error') {
+      console.error(line);
+      log('err', line);
+      errLines.push(line);
+    } else {
+      console.warn(line);
+    }
+  }
+  if (errLines.length > 0) throw new Error(`WGSL compile failed for ${key}: ${errLines.slice(0, 4).join(' | ')}`);
+  return device.createComputePipelineAsync({
+    layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }),
+    compute: { module, entryPoint: 'main' },
+  });
+}
+
+/**
+ * Creates the four-binding compute layout shared by all kernels in this file:
+ * bindings 0 and 1 are read-only storage, binding 2 is writable storage and
+ * binding 3 is a uniform buffer.
+ *
+ * @param device Device to create the layout on.
+ * @returns A new bind group layout.
+ */
+function ioLayout4(device: GPUDevice): GPUBindGroupLayout {
+  // The position in this list is the binding number.
+  const bufferTypes = ['read-only-storage', 'read-only-storage', 'storage', 'uniform'] as const;
+  const entries = bufferTypes.map((type, binding) => ({
+    binding,
+    visibility: GPUShaderStage.COMPUTE,
+    buffer: { type },
+  }));
+  return device.createBindGroupLayout({ entries });
+}
+
+/**
+ * Copies the first `bytes` bytes of a GPU buffer into CPU memory.
+ *
+ * @param device Device that owns `buf`.
+ * @param buf Source buffer; must have been created with COPY_SRC usage.
+ * @param bytes Number of bytes to read, starting at offset 0.
+ * @returns A Uint32Array backed by a private copy of the data.
+ */
+async function readbackU32(device: GPUDevice, buf: GPUBuffer, bytes: number): Promise<Uint32Array> {
+  const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST });
+  try {
+    const enc = device.createCommandEncoder();
+    enc.copyBufferToBuffer(buf, 0, staging, 0, bytes);
+    device.queue.submit([enc.finish()]);
+    await device.queue.onSubmittedWorkDone();
+    await staging.mapAsync(GPUMapMode.READ);
+    // slice(0) copies the mapped range so the data outlives unmap()/destroy().
+    const out = new Uint32Array(staging.getMappedRange().slice(0));
+    staging.unmap();
+    return out;
+  } finally {
+    // Release the staging buffer even when the copy or the mapping rejects.
+    staging.destroy();
+  }
+}
+
+/**
+ * Runs the v2 pair-tree bucket reduction on the GPU and checks every non-empty
+ * bucket against a CPU reference computed with @noble/curves.
+ *
+ * Per level the host plans pairs and carries (buildLevelPlan) and records up
+ * to four passes: marshal, disjoint pair-sum, scatter and, when the level has
+ * carries, carry-copy. The two element buffers swap roles after every level.
+ * All passes are encoded into one command buffer and submitted once.
+ *
+ * All GPU buffers created here are destroyed before returning, also when the
+ * run throws.
+ *
+ * @param device WebGPU device.
+ * @param sm Shader manager that generates the four WGSL kernels.
+ * @param R Montgomery factor applied to the input coordinates.
+ * @param Rinv Inverse of R modulo p, used to decode the GPU output.
+ * @param p Base-field modulus.
+ * @returns Run parameters, level and pair counts, the comparison summary and
+ *   the wall time of the single submit.
+ * @throws Error when more than 25 levels would be needed (safety cap) or a
+ *   shader fails to compile.
+ */
+async function runOracle(device: GPUDevice, sm: ShaderManager, R: bigint, Rinv: bigint, p: bigint): Promise<OracleResult> {
+  log('info', `=== N=${NPTS} B=${BUCKETS} S=${S} WGI=${WGI}`);
+  const rng = makeRng(SEED);
+  const { initBuf, initCounts, initOffsets, M, points, bucket } = buildL0WithRealPoints(NPTS, BUCKETS, R, p, rng);
+
+  let cMin = NPTS, cMax = 0, cZero = 0;
+  for (let b = 0; b < BUCKETS; b++) {
+    if (initCounts[b] > cMax) cMax = initCounts[b];
+    if (initCounts[b] < cMin) cMin = initCounts[b];
+    if (initCounts[b] === 0) cZero++;
+  }
+  log('info', `built L0: M=${M} bucket counts min=${cMin} max=${cMax} zero=${cZero}/${BUCKETS}`);
+
+  // The last two element slots are the pad operands. Results of unused pair
+  // slots are scattered to discardIdx, which is the same slot as padLIdx.
+  const padLIdx = M - 2;
+  const padRIdx = M - 1;
+  const discardIdx = M - 2;
+
+  // Every buffer passed through own() is released by the finally block.
+  const owned: GPUBuffer[] = [];
+  const own = (buf: GPUBuffer): GPUBuffer => {
+    owned.push(buf);
+    return buf;
+  };
+
+  try {
+    // Both element buffers start with the level-0 data (including the pads).
+    const bufA = own(makeSoABuf(device, M, true, true));
+    const bufB = own(makeSoABuf(device, M, true, true));
+    device.queue.writeBuffer(bufA, 0, initBuf);
+    device.queue.writeBuffer(bufB, 0, initBuf);
+
+    // Scratch buffers, sized from NPTS and shared by every level.
+    const maxL0Chunks = Math.ceil(NPTS / 2 / S) + 1;
+    const chainBuf = own(makeSoABuf(device, 2 * S * maxL0Chunks, false, false));
+    const tempOutBuf = own(makeSoABuf(device, S * maxL0Chunks, false, true));
+
+    const layoutMarshal = ioLayout4(device);
+    const layoutDisjoint = ioLayout4(device);
+    const layoutScatter = ioLayout4(device);
+    const layoutCarry = ioLayout4(device);
+    const marshalPipe = await compileOne(device, sm.gen_ba_marshal_pairs_bench_shader(WGI, S), `marshal-W${WGI}-S${S}`, layoutMarshal);
+    const disjointPipe = await compileOne(device, sm.gen_ba_pair_disjoint_tree_bench_shader(WGI, S), `disjoint-W${WGI}-S${S}`, layoutDisjoint);
+    const scatterPipe = await compileOne(device, sm.gen_ba_scatter_pairs_bench_shader(WGI, S), `scatter-W${WGI}-S${S}`, layoutScatter);
+    const carryPipe = await compileOne(device, sm.gen_ba_carry_copy_bench_shader(WGI), `carry-W${WGI}`, layoutCarry);
+    log('info', '4 pipelines compiled');
+
+    let counts = initCounts;
+    let offsets = initOffsets;
+    let finalOffsets: Uint32Array = initOffsets;
+    let curIn: GPUBuffer = bufA;
+    let curOut: GPUBuffer = bufB;
+    let totalPairAdds = 0;
+    let levelIdx = 0;
+    // Placeholder for binding 1 of the disjoint kernel.
+    const dummy = own(device.createBuffer({ size: 16, usage: GPUBufferUsage.STORAGE }));
+
+    interface PassSpec { pipeline: GPUComputePipeline; bind: GPUBindGroup; numWgs: number }
+    const allPasses: PassSpec[] = [];
+
+    for (;;) {
+      // Stop once every bucket holds at most one element.
+      let maxCount = 0;
+      for (let b = 0; b < counts.length; b++) if (counts[b] > maxCount) maxCount = counts[b];
+      if (maxCount <= 1) {
+        finalOffsets = offsets;
+        break;
+      }
+      if (levelIdx > 24) throw new Error('exceeded safety level cap');
+
+      const plan = buildLevelPlan(counts, offsets, S, padLIdx, padRIdx, discardIdx);
+      totalPairAdds += plan.totalPairs;
+      const T = plan.numChunks;
+      const numWgs = Math.ceil(T / WGI);
+      log('info', `L${levelIdx}: T=${T} pairs=${plan.totalPairs} carries=${plan.numCarries} maxCount=${maxCount}`);
+
+      const chunkPlanBuf = own(device.createBuffer({ size: plan.chunkPlan.byteLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST }));
+      device.queue.writeBuffer(chunkPlanBuf, 0, plan.chunkPlan);
+      const scatterPlanBuf = own(device.createBuffer({ size: plan.scatterPlan.byteLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST }));
+      device.queue.writeBuffer(scatterPlanBuf, 0, plan.scatterPlan);
+      const carryPlanBuf = own(device.createBuffer({ size: plan.carryPlan.byteLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST }));
+      device.queue.writeBuffer(carryPlanBuf, 0, plan.carryPlan);
+
+      // 16-byte uniform blocks; the meaning of each word is defined by the
+      // generated shaders.
+      const marshalParams = own(device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }));
+      device.queue.writeBuffer(marshalParams, 0, new Uint32Array([T, M, 0, 0]));
+      const disjointParams = own(device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }));
+      device.queue.writeBuffer(disjointParams, 0, new Uint32Array([2 * S * T, T, 1, 0]));
+      const scatterParams = own(device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }));
+      device.queue.writeBuffer(scatterParams, 0, new Uint32Array([T, M, 0, 0]));
+      const carryParams = own(device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }));
+      if (plan.numCarries > 0) {
+        device.queue.writeBuffer(carryParams, 0, new Uint32Array([plan.numCarries, M, M, 0]));
+      }
+
+      const marshalBind = device.createBindGroup({
+        layout: layoutMarshal,
+        entries: [
+          { binding: 0, resource: { buffer: chunkPlanBuf } },
+          { binding: 1, resource: { buffer: curIn } },
+          { binding: 2, resource: { buffer: chainBuf } },
+          { binding: 3, resource: { buffer: marshalParams } },
+        ],
+      });
+      const disjointBind = device.createBindGroup({
+        layout: layoutDisjoint,
+        entries: [
+          { binding: 0, resource: { buffer: chainBuf } },
+          { binding: 1, resource: { buffer: dummy } },
+          { binding: 2, resource: { buffer: tempOutBuf } },
+          { binding: 3, resource: { buffer: disjointParams } },
+        ],
+      });
+      const scatterBind = device.createBindGroup({
+        layout: layoutScatter,
+        entries: [
+          { binding: 0, resource: { buffer: scatterPlanBuf } },
+          { binding: 1, resource: { buffer: tempOutBuf } },
+          { binding: 2, resource: { buffer: curOut } },
+          { binding: 3, resource: { buffer: scatterParams } },
+        ],
+      });
+      let carryBind: GPUBindGroup | null = null;
+      if (plan.numCarries > 0) {
+        carryBind = device.createBindGroup({
+          layout: layoutCarry,
+          entries: [
+            { binding: 0, resource: { buffer: carryPlanBuf } },
+            { binding: 1, resource: { buffer: curIn } },
+            { binding: 2, resource: { buffer: curOut } },
+            { binding: 3, resource: { buffer: carryParams } },
+          ],
+        });
+      }
+
+      allPasses.push({ pipeline: marshalPipe, bind: marshalBind, numWgs });
+      allPasses.push({ pipeline: disjointPipe, bind: disjointBind, numWgs });
+      allPasses.push({ pipeline: scatterPipe, bind: scatterBind, numWgs });
+      if (plan.numCarries > 0 && carryBind) {
+        const carryWgs = Math.ceil(plan.numCarries / WGI);
+        allPasses.push({ pipeline: carryPipe, bind: carryBind, numWgs: carryWgs });
+      }
+
+      // This level's output buffer is the next level's input.
+      counts = plan.newCounts;
+      offsets = plan.newOffsets;
+      [curIn, curOut] = [curOut, curIn];
+      levelIdx++;
+    }
+
+    const enc = device.createCommandEncoder();
+    for (const ps of allPasses) {
+      const pass = enc.beginComputePass();
+      pass.setPipeline(ps.pipeline);
+      pass.setBindGroup(0, ps.bind);
+      pass.dispatchWorkgroups(ps.numWgs, 1, 1);
+      pass.end();
+    }
+    const t0 = performance.now();
+    device.queue.submit([enc.finish()]);
+    await device.queue.onSubmittedWorkDone();
+    const gpuWall = performance.now() - t0;
+    log('info', `batched ${allPasses.length} passes in one submit: ${gpuWall.toFixed(2)} ms`);
+
+    // After the final swap, curIn is the buffer the last level wrote to.
+    const result = await readbackU32(device, curIn, 2 * PG * M * 4 * 4);
+
+    // Reads element `slot` from both planes and returns the raw (still
+    // Montgomery-encoded) coordinates.
+    const decodeAt = (slot: number): { x_mont: bigint; y_mont: bigint } => {
+      const xWords = new Uint32Array(8);
+      const yWords = new Uint32Array(8);
+      const planeBaseX = 0 * PG * M;
+      const planeBaseY = 1 * PG * M;
+      for (let v = 0; v < PG; v++) {
+        const baseX = (planeBaseX + PG * slot + v) * 4;
+        const baseY = (planeBaseY + PG * slot + v) * 4;
+        xWords[4 * v + 0] = result[baseX + 0];
+        xWords[4 * v + 1] = result[baseX + 1];
+        xWords[4 * v + 2] = result[baseX + 2];
+        xWords[4 * v + 3] = result[baseX + 3];
+        yWords[4 * v + 0] = result[baseY + 0];
+        yWords[4 * v + 1] = result[baseY + 1];
+        yWords[4 * v + 2] = result[baseY + 2];
+        yWords[4 * v + 3] = result[baseY + 3];
+      }
+      const x_mont = packedU32x8ToBigint(xWords, 0);
+      const y_mont = packedU32x8ToBigint(yWords, 0);
+      return { x_mont, y_mont };
+    };
+
+    // CPU reference per non-empty bucket; `null` marks the point at infinity.
+    const G1 = bn254.G1.ProjectivePoint;
+    const refSumPerBucket = new Map<number, { x: bigint; y: bigint } | null>();
+    for (let b = 0; b < BUCKETS; b++) {
+      if (initCounts[b] === 0) continue;
+      let acc = G1.ZERO;
+      for (let i = 0; i < NPTS; i++) {
+        if (bucket[i] !== b) continue;
+        acc = acc.add(G1.fromAffine({ x: points[i].x, y: points[i].y }));
+      }
+      refSumPerBucket.set(b, acc.is0() ? null : acc.toAffine());
+    }
+
+    const checks: BucketCheck[] = [];
+    const mismatches: BucketCheck[] = [];
+    let passCount = 0;
+    for (let b = 0; b < BUCKETS; b++) {
+      if (initCounts[b] === 0) continue;
+      // Each surviving bucket holds one element, at its final offset.
+      const slot = finalOffsets[b];
+      const { x_mont, y_mont } = decodeAt(slot);
+      const gx = (x_mont * Rinv) % p;
+      const gy = (y_mont * Rinv) % p;
+      const ref = refSumPerBucket.get(b);
+      let ok = false;
+      if (ref === null) {
+        // A reference at infinity is matched against an all-zero GPU value.
+        ok = gx === 0n && gy === 0n;
+      } else if (ref) {
+        ok = gx === ref.x && gy === ref.y;
+      }
+      const entry: BucketCheck = {
+        bucket: b,
+        count: initCounts[b],
+        gpu_x: gx.toString(16),
+        gpu_y: gy.toString(16),
+        ref_x: ref ? ref.x.toString(16) : 'INF',
+        ref_y: ref ? ref.y.toString(16) : 'INF',
+        ok,
+      };
+      checks.push(entry);
+      if (ok) passCount++;
+      else if (mismatches.length < 8) mismatches.push(entry);
+    }
+    const allPassed = mismatches.length === 0 && passCount === checks.length;
+
+    if (allPassed) {
+      log('ok', `oracle PASS — ${passCount}/${checks.length} buckets match noble reference`);
+    } else {
+      log('err', `oracle FAIL — ${checks.length - passCount}/${checks.length} buckets diverged (showing first ${mismatches.length})`);
+      for (const m of mismatches) {
+        log('err', ` bucket ${m.bucket} (count=${m.count})`);
+        log('err', ` gpu: x=${m.gpu_x} y=${m.gpu_y}`);
+        log('err', ` ref: x=${m.ref_x} y=${m.ref_y}`);
+      }
+    }
+
+    return {
+      n: NPTS,
+      buckets: BUCKETS,
+      s: S,
+      wgi: WGI,
+      levels: levelIdx,
+      total_pair_adds: totalPairAdds,
+      buckets_checked: checks.length,
+      buckets_passed: passCount,
+      first_mismatches: mismatches,
+      all_passed: allPassed,
+      gpu_wall_ms: gpuWall,
+    };
+  } finally {
+    for (const buf of owned) buf.destroy();
+  }
+}
+
+/**
+ * Applies URL query parameters to the module-level run configuration.
+ *
+ * A parameter that is absent or empty keeps its default. Values are read as
+ * decimal integers, or as hexadecimal when they start with `0x` (so
+ * `?seed=0xa110ce` works; a radix-10 parse would silently yield 0).
+ *
+ * @returns The effective configuration after applying the overrides.
+ * @throws Error when a supplied value is not an integer or is below the
+ *   minimum for that parameter (1 for every size, unbounded for the seed).
+ */
+function parseParams() {
+  const qp = new URLSearchParams(window.location.search);
+  const readInt = (key: string, fallback: number, min: number): number => {
+    const raw = qp.get(key);
+    if (!raw) return fallback;
+    const value = /^\s*[-+]?0x/i.test(raw) ? parseInt(raw, 16) : parseInt(raw, 10);
+    if (!Number.isSafeInteger(value) || value < min) {
+      throw new Error(`invalid ?${key}=${raw}: expected an integer >= ${min}`);
+    }
+    return value;
+  };
+  NPTS = readInt('n', NPTS, 1);
+  BUCKETS = readInt('buckets', BUCKETS, 1);
+  S = readInt('s', S, 1);
+  WGI = readInt('wgi', WGI, 1);
+  // Any integer is a valid seed; makeRng truncates it to 32 bits.
+  SEED = readInt('seed', SEED, Number.MIN_SAFE_INTEGER);
+  return { n: NPTS, buckets: BUCKETS, s: S, wgi: WGI, seed: SEED };
+}
+
+/**
+ * Page driver: checks for WebGPU, parses the query parameters, acquires a
+ * device, runs the oracle once and publishes the outcome through benchState
+ * and the results client. Any failure is logged and stored in benchState
+ * instead of being rethrown.
+ */
+async function main() {
+  try {
+    if (!('gpu' in navigator)) {
+      throw new Error('navigator.gpu missing');
+    }
+    const params = parseParams();
+    benchState.params = params;
+    log('info', `params: ${JSON.stringify(params)}`);
+    benchState.state = 'running';
+
+    const device = await get_device();
+    log('info', 'WebGPU device acquired');
+
+    // Montgomery factor for the base field and its inverse for decoding.
+    const modulus = BN254_BASE_FIELD;
+    const { r: montR } = compute_misc_params(modulus, 13);
+    const montRInv = modInverse(montR, modulus);
+    const shaders = new ShaderManager(1, BUCKETS, BN254_CURVE_CONFIG, false);
+
+    const outcome = await runOracle(device, shaders, montR, montRInv, modulus);
+    benchState.results.push(outcome);
+    const { all_passed, buckets_passed, buckets_checked, gpu_wall_ms } = outcome;
+    resultsClient.postProgress({ kind: 'oracle_done', all_passed, buckets_passed, buckets_checked, gpu_wall_ms });
+
+    benchState.state = 'done';
+    log('ok', 'done');
+  } catch (e) {
+    let msg: string;
+    if (e instanceof Error) {
+      msg = `${e.message}\n${e.stack}`;
+    } else {
+      msg = String(e);
+    }
+    log('err', `FATAL: ${msg}`);
+    benchState.state = 'error';
+    benchState.error = msg;
+  }
+}
+
+// Entry point. main() handles its own errors; the catch here only covers a
+// rejection that still escapes, and the upload is attempted in every case.
+void (async () => {
+  try {
+    await main();
+  } catch (e) {
+    const msg = e instanceof Error ? e.message : String(e);
+    log('err', `unhandled: ${msg}`);
+    benchState.state = 'error';
+    benchState.error = msg;
+  } finally {
+    // Upload failures are deliberately ignored (best effort).
+    postFinal().catch(() => {});
+  }
+})();
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.html b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.html
new file mode 100644
index 000000000000..bd33e10887af
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.html
@@ -0,0 +1,22 @@
+
+
+
+
+ Bin-packed pair-tree MSM bucket-accumulate v2 (WebGPU)
+
+
+
+ Bin-packed pair-tree MSM bucket-accumulate v2 (WebGPU)
+ Query params: ?reps=R&n=N&buckets=B&s=S&wgi=W
+
+
+
+
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.ts
new file mode 100644
index 000000000000..b0f97f6081d8
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.ts
@@ -0,0 +1,566 @@
+///
+// bench-msm-tree-v2 — bin-packed pair-tree MSM bucket-accumulate with
+// carry-forward. Eliminates the slow per-bucket tail kernel by packing
+// pairs from any combination of buckets into the same chunk.
+//
+// For each (chunk t, slot k), the disjoint kernel sums (P_{2k}, P_{2k+1}).
+// Both operands of each pair come from the SAME bucket; different
+// (chunk, slot) entries can come from DIFFERENT buckets. The planner
+// guarantees the within-pair bucket invariant.
+//
+// Per level transition:
+// 1. host: per-bucket pair-count + carry, bin-pack into chunks of S.
+// 2. marshal-pairs (GPU): gather operands per chunk_plan into chain_buf.
+// 3. tree-disjoint (GPU, final=1): chain_buf -> simple strided output.
+// 4. scatter-pairs (GPU): outputs -> active_sums_new at per-bucket positions.
+// 5. carry-copy (GPU): odd-count carries -> active_sums_new (if any).
+// 6. swap active_sums buffers, update counts/offsets.
+//
+// Terminate when no bucket holds more than one element (max bucket count <= 1).
+
+import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
+import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
+import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
+import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js';
+import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js';
+import { makeResultsClient } from './results_post.js';
+
+// Number of 4-word (4 x u32) groups each 8-word field element is split into
+// in the SoA buffers (see writeElem / makeSoABuf).
+const PG = 2;
+// Defaults, overridable through the ?n / ?buckets / ?s / ?wgi query parameters.
+const DEFAULT_N = 1 << 17;
+const DEFAULT_BUCKETS = 1 << 12;
+const DEFAULT_S = 16;
+const DEFAULT_WGI = 64;
+
+// Effective settings; parseParams() overwrites these before the run starts.
+let NPTS = DEFAULT_N;
+let BUCKETS = DEFAULT_BUCKETS;
+let S = DEFAULT_S;
+let WGI = DEFAULT_WGI;
+
+/**
+ * Creates a deterministic 32-bit linear congruential generator.
+ *
+ * @param seed Initial state, truncated to an unsigned 32-bit value; a seed
+ *   that truncates to 0 is replaced by 1.
+ * @returns A function yielding the next unsigned 32-bit state on each call.
+ */
+function makeRng(seed: number): () => number {
+  const MULTIPLIER = 1664525;
+  const INCREMENT = 1013904223;
+  const truncated = seed >>> 0;
+  let current = truncated === 0 ? 1 : truncated;
+  return function next(): number {
+    current = (Math.imul(current, MULTIPLIER) + INCREMENT) >>> 0;
+    return current;
+  };
+}
+
+/**
+ * Draws a value v with 0 < v < p by rejection sampling.
+ *
+ * Each byte is taken from the TOP 8 bits of a 32-bit `rng()` output. The
+ * generator used in this file is a power-of-two-modulus LCG whose low byte
+ * repeats every 256 calls; sampling the low byte (as before) meant that
+ * consecutive 32-byte draws could only ever produce a handful of distinct
+ * values.
+ *
+ * @param p Exclusive upper bound; must be greater than 1.
+ * @param rng Source of unsigned 32-bit values.
+ * @returns A value in the open interval (0, p).
+ * @throws RangeError if p <= 1 (no valid value exists; the loop would never end).
+ */
+function randomBelow(p: bigint, rng: () => number): bigint {
+  if (p <= 1n) throw new RangeError('randomBelow: p must be greater than 1');
+  const bitlen = p.toString(2).length;
+  const byteLen = Math.ceil(bitlen / 8);
+  // Loop-invariant: trims the candidate to the bit length of p.
+  const mask = (1n << BigInt(bitlen)) - 1n;
+  for (;;) {
+    let v = 0n;
+    for (let i = 0; i < byteLen; i++) v = (v << 8n) | BigInt(rng() >>> 24);
+    v &= mask;
+    if (v > 0n && v < p) return v;
+  }
+}
+
+/**
+ * Splits a bigint into eight 32-bit words, least-significant word first.
+ * Bits above the low 256 are dropped.
+ */
+function bigintToPackedU32x8(v: bigint): Uint32Array {
+  return Uint32Array.from({ length: 8 }, (_, wordIdx) => Number(BigInt.asUintN(32, v >> BigInt(32 * wordIdx))));
+}
+
+/**
+ * Allocates a STORAGE buffer sized for M elements in the SoA layout:
+ * 2 planes x PG groups x M elements x 4 words x 4 bytes.
+ *
+ * @param copyDst Also allow the buffer to be a copy/write destination.
+ * @param copySrc Also allow the buffer to be a copy source.
+ */
+function makeSoABuf(device: GPUDevice, M: number, copyDst: boolean, copySrc: boolean): GPUBuffer {
+  const usage =
+    GPUBufferUsage.STORAGE |
+    (copyDst ? GPUBufferUsage.COPY_DST : 0) |
+    (copySrc ? GPUBufferUsage.COPY_SRC : 0);
+  return device.createBuffer({ size: 2 * PG * M * 4 * 4, usage });
+}
+
+/**
+ * Builds the level-0 active_sums image.
+ *
+ * N random (x, y) values (multiplied by R mod p) are assigned to random
+ * buckets and stored bucket-major: bucket b occupies element indices
+ * offsets[b] .. offsets[b] + counts[b] - 1. Two extra elements at indices
+ * M-2 and M-1 form a "pad pair" whose x values are forced to differ, so
+ * chunk-tail slots can be filled without a divide-by-zero.
+ *
+ * @returns The packed buffer image, per-bucket counts, prefix-sum offsets
+ *   (length B + 1) and the element count M = N + 2.
+ */
+function buildL0ActiveSums(N: number, B: number, R: bigint, p: bigint, rng: () => number) {
+  const M = N + 2;
+  const padL = M - 2;
+  const padR = M - 1;
+
+  // Draw M coordinate pairs; x is drawn before y for each element.
+  const xWords = new Uint32Array(8 * M);
+  const yWords = new Uint32Array(8 * M);
+  for (let i = 0; i < M; i++) {
+    const xMont = (randomBelow(p, rng) * R) % p;
+    const yMont = (randomBelow(p, rng) * R) % p;
+    xWords.set(bigintToPackedU32x8(xMont), 8 * i);
+    yWords.set(bigintToPackedU32x8(yMont), 8 * i);
+  }
+  // If the pad pair agrees in its lowest x word, flip one bit so the x's differ.
+  if (xWords[8 * padL] === xWords[8 * padR]) {
+    xWords[8 * padR] ^= 1;
+  }
+
+  // Bucket assignment from the upper halves of two generator outputs.
+  const bucketOf = new Uint32Array(N);
+  const counts = new Uint32Array(B);
+  for (let i = 0; i < N; i++) {
+    const hi = (rng() >>> 16) & 0xffff;
+    const lo = (rng() >>> 16) & 0xffff;
+    const b = (hi * 0x10000 + lo) % B;
+    bucketOf[i] = b;
+    counts[b]++;
+  }
+  const offsets = new Uint32Array(B + 1);
+  counts.forEach((c, b) => {
+    offsets[b + 1] = offsets[b] + c;
+  });
+
+  const buf = new Uint32Array(2 * PG * M * 4);
+  // Copies one 8-word coordinate into its PG groups of 4 words for `slot`.
+  const storeCoord = (plane: number, slot: number, words: Uint32Array, src: number) => {
+    for (let g = 0; g < PG; g++) {
+      const dst = ((plane * PG + g) * M + slot) * 4;
+      buf.set(words.subarray(src + 4 * g, src + 4 * g + 4), dst);
+    }
+  };
+  // Stores element `srcIdx` (x into plane 0, y into plane 1) at `slot`.
+  const storePoint = (slot: number, srcIdx: number) => {
+    storeCoord(0, slot, xWords, 8 * srcIdx);
+    storeCoord(1, slot, yWords, 8 * srcIdx);
+  };
+
+  const filled = new Uint32Array(B);
+  for (let i = 0; i < N; i++) {
+    const b = bucketOf[i];
+    storePoint(offsets[b] + filled[b]++, i);
+  }
+  // Pad pair at indices M-2, M-1.
+  storePoint(padL, padL);
+  storePoint(padR, padR);
+
+  return { initBuf: buf, initCounts: counts, initOffsets: offsets, M };
+}
+
+/**
+ * Plans one level of the pair tree.
+ *
+ * Every bucket with n elements contributes floor(n / 2) pairs and, when n
+ * is odd, one carry. Pairs are packed consecutively into chunks of `s`
+ * slots; unused slots read the pad pair and scatter to `discardIdx`.
+ *
+ * @param counts Per-bucket element counts at this level.
+ * @param offsets Per-bucket start indices at this level.
+ * @param s Pair slots per chunk.
+ * @param padLIdx Operand index used for the left side of unused slots.
+ * @param padRIdx Operand index used for the right side of unused slots.
+ * @param discardIdx Destination index used for unused slots.
+ * @returns chunkPlan (2 operand indices per slot), scatterPlan (1 destination
+ *   per slot), carryPlan ((source, destination) per carry), the next level's
+ *   counts/offsets, and the totals.
+ */
+function buildLevelPlan(
+  counts: Uint32Array,
+  offsets: Uint32Array,
+  s: number,
+  padLIdx: number,
+  padRIdx: number,
+  discardIdx: number,
+) {
+  const bucketCount = counts.length;
+  const newCounts = new Uint32Array(bucketCount);
+  const newOffsets = new Uint32Array(bucketCount + 1);
+  let totalPairs = 0;
+  let totalCarries = 0;
+  counts.forEach((n, b) => {
+    const pairs = n >>> 1;
+    const carry = n & 1;
+    totalPairs += pairs;
+    totalCarries += carry;
+    newCounts[b] = pairs + carry;
+    newOffsets[b + 1] = newOffsets[b] + newCounts[b];
+  });
+
+  // Always at least one chunk and one carry entry, so no plan buffer is empty.
+  const numChunks = Math.max(1, Math.ceil(totalPairs / s));
+  const slotCount = s * numChunks;
+  const chunkPlan = new Uint32Array(2 * slotCount);
+  const scatterPlan = new Uint32Array(slotCount).fill(discardIdx);
+  const carryPlan = new Uint32Array(2 * Math.max(1, totalCarries));
+  for (let i = 0; i < slotCount; i++) {
+    chunkPlan[2 * i] = padLIdx;
+    chunkPlan[2 * i + 1] = padRIdx;
+  }
+
+  let nextSlot = 0;
+  let nextCarry = 0;
+  counts.forEach((n, b) => {
+    const pairs = n >>> 1;
+    const srcBase = offsets[b];
+    const dstBase = newOffsets[b];
+    for (let j = 0; j < pairs; j++, nextSlot++) {
+      chunkPlan[2 * nextSlot] = srcBase + 2 * j;
+      chunkPlan[2 * nextSlot + 1] = srcBase + 2 * j + 1;
+      scatterPlan[nextSlot] = dstBase + j;
+    }
+    if ((n & 1) !== 0) {
+      // The odd element out moves to the slot just after this bucket's pair sums.
+      carryPlan[2 * nextCarry] = srcBase + n - 1;
+      carryPlan[2 * nextCarry + 1] = dstBase + pairs;
+      nextCarry++;
+    }
+  });
+
+  return { chunkPlan, scatterPlan, carryPlan, newCounts, newOffsets, numChunks, numCarries: totalCarries, totalPairs };
+}
+
+// Per-level record. runPipeline fills T / pairs / carries from the level plan;
+// the *_ms fields are recorded as 0 because all levels share one submit.
+interface LevelTiming {
+ T: number;
+ pairs: number;
+ carries: number;
+ marshal_ms: number;
+ disjoint_ms: number;
+ scatter_ms: number;
+ carry_ms: number;
+}
+
+// Summary of one runPipeline() call. Note that `pairs` is filled with the
+// input point count (NPTS) and `total_wall_ms` with the batched-submit wall time.
+interface RunResult {
+ s: number;
+ wgi: number;
+ pairs: number;
+ buckets: number;
+ levels: number;
+ total_pair_adds: number;
+ total_wall_ms: number;
+ level_timings: LevelTiming[];
+ ns_per_inpt: number;
+ sanity_ok: boolean;
+}
+
+// Mutable page state: lifecycle phase, parsed parameters, collected results,
+// the fatal error message (if any) and every line passed to log().
+interface BenchState {
+ state: 'boot' | 'running' | 'done' | 'error';
+ params: { reps: number; n: number; buckets: number; s: number; wgi: number } | null;
+ results: RunResult[];
+ error: string | null;
+ log: string[];
+}
+
+// Shared page state, also published as window.__bench (presumably so an
+// external harness can poll it — confirm against the runner).
+const benchState: BenchState = { state: 'boot', params: null, results: [], error: null, log: [] };
+(window as unknown as { __bench: BenchState }).__bench = benchState;
+
+// Results client for this page; its run id is published as window.__runId.
+const resultsClient = makeResultsClient({ page: 'bench-msm-tree-v2' });
+(window as unknown as { __runId: string }).__runId = resultsClient.runId;
+
+/**
+ * Posts the final bench report (state, parameters, results, error, log and
+ * basic browser information) through the results client.
+ */
+async function postFinal(): Promise<void> {
+  await resultsClient.postResults({
+    state: benchState.state,
+    params: benchState.params,
+    results: benchState.results,
+    error: benchState.error,
+    log: benchState.log,
+    userAgent: navigator.userAgent,
+    hardwareConcurrency: navigator.hardwareConcurrency,
+  });
+}
+
+// On-page log container; null when the page has no #log element.
+const $log = document.getElementById('log') as HTMLDivElement | null;
+
+/**
+ * Records a message in three places: the on-page log (when present),
+ * `benchState.log`, and the browser console.
+ *
+ * A missing #log element no longer makes logging throw, so errors raised
+ * before or without the DOM container still reach the report.
+ *
+ * @param level Severity; also used as the CSS class, except 'info' which has none.
+ * @param msg Message text (inserted as text, never as HTML).
+ */
+function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) {
+  if ($log) {
+    const row = document.createElement('div');
+    row.className = level === 'info' ? '' : level;
+    row.textContent = msg;
+    $log.appendChild(row);
+  }
+  benchState.log.push(`[${level}] ${msg}`);
+  console.log(`[bench-msm-tree-v2] ${msg}`);
+}
+
+/**
+ * Compiles one WGSL module and builds a compute pipeline for its `main`
+ * entry point with a single bind-group layout.
+ *
+ * Compiler errors are logged to the console and the page log; other
+ * messages go to `console.warn` only.
+ *
+ * @param key Label used in diagnostics.
+ * @throws Error if the compiler reports any error (first four are included).
+ */
+async function compileOne(device: GPUDevice, code: string, key: string, layout: GPUBindGroupLayout): Promise<GPUComputePipeline> {
+  const module = device.createShaderModule({ code });
+  const info = await module.getCompilationInfo();
+  const errLines: string[] = [];
+  for (const m of info.messages) {
+    const line = `[shader ${key}] ${m.type}: ${m.message} (line ${m.lineNum}, col ${m.linePos})`;
+    if (m.type === 'error') {
+      console.error(line);
+      log('err', line);
+      errLines.push(line);
+    } else {
+      console.warn(line);
+    }
+  }
+  if (errLines.length > 0) throw new Error(`WGSL compile failed for ${key}: ${errLines.slice(0, 4).join(' | ')}`);
+  return device.createComputePipelineAsync({
+    layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }),
+    compute: { module, entryPoint: 'main' },
+  });
+}
+
+/**
+ * Creates the four-binding compute layout shared by all kernels here:
+ * bindings 0 and 1 are read-only storage, 2 is read-write storage and
+ * 3 is a uniform buffer.
+ */
+function ioLayout4(device: GPUDevice): GPUBindGroupLayout {
+  const bindingKinds = ['read-only-storage', 'read-only-storage', 'storage', 'uniform'] as const;
+  return device.createBindGroupLayout({
+    entries: bindingKinds.map((type, binding) => ({
+      binding,
+      visibility: GPUShaderStage.COMPUTE,
+      buffer: { type },
+    })),
+  });
+}
+
+/**
+ * Copies the first `u32Count` 32-bit words of `buf` to the host and reports
+ * whether any of them is non-zero.
+ *
+ * The staging buffer is now released even when the copy or mapping fails,
+ * and the mapped range is scanned in place instead of being copied first.
+ *
+ * @param buf Source buffer; it is used as a copy source.
+ * @returns true if at least one of the inspected words is non-zero.
+ */
+async function readNonZero(device: GPUDevice, buf: GPUBuffer, u32Count: number): Promise<boolean> {
+  const bytes = u32Count * 4;
+  const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST });
+  try {
+    const enc = device.createCommandEncoder();
+    enc.copyBufferToBuffer(buf, 0, staging, 0, bytes);
+    device.queue.submit([enc.finish()]);
+    await device.queue.onSubmittedWorkDone();
+    await staging.mapAsync(GPUMapMode.READ);
+    // Scan before unmap(): unmapping detaches the mapped ArrayBuffer.
+    const anyNonZero = new Uint32Array(staging.getMappedRange()).some(word => word !== 0);
+    staging.unmap();
+    return anyNonZero;
+  } finally {
+    staging.destroy();
+  }
+}
+
+// One compute pass for timeBatched: pipeline, bind group 0, and the number of
+// workgroups dispatched along x.
+interface PassSpec { pipeline: GPUComputePipeline; bind: GPUBindGroup; numWgs: number }
+
+/**
+ * Encodes every pass into one command encoder, submits once and awaits
+ * once, so submit overhead is paid a single time across all passes.
+ *
+ * Encoding happens before the timer starts; the measured interval covers
+ * the submit call and the wait for the queue to drain.
+ *
+ * @param passes Passes to run, in order; each becomes its own compute pass.
+ * @returns Wall time in milliseconds from submit until the work is done.
+ */
+async function timeBatched(device: GPUDevice, passes: PassSpec[]): Promise<number> {
+  const enc = device.createCommandEncoder();
+  for (const spec of passes) {
+    const pass = enc.beginComputePass();
+    pass.setPipeline(spec.pipeline);
+    pass.setBindGroup(0, spec.bind);
+    pass.dispatchWorkgroups(spec.numWgs, 1, 1);
+    pass.end();
+  }
+  const commands = enc.finish();
+  const t0 = performance.now();
+  device.queue.submit([commands]);
+  await device.queue.onSubmittedWorkDone();
+  return performance.now() - t0;
+}
+
+/**
+ * Runs the whole bin-packed pair-tree bucket-accumulate once.
+ *
+ * Builds the synthetic level-0 input, then for each level plans the pairs
+ * on the host, uploads the plans, and records the marshal -> disjoint ->
+ * scatter (-> carry-copy) passes. All levels are executed in a single
+ * queue submission at the end.
+ *
+ * Every GPU buffer created here stays alive until that submission and the
+ * sanity readback have finished. Destroying the per-level plan/uniform
+ * buffers before the submit (as was done previously) makes the submit a
+ * validation error, so none of the recorded passes would execute.
+ *
+ * @param reps Accepted for interface compatibility; not used by this function.
+ * @param R Multiplier applied to the random inputs (see buildL0ActiveSums).
+ * @param p Field modulus.
+ * @returns Timing and bookkeeping for the run.
+ * @throws Error if more than 25 levels would be needed.
+ */
+async function runPipeline(device: GPUDevice, sm: ShaderManager, reps: number, R: bigint, p: bigint): Promise<RunResult> {
+  log('info', `=== N=${NPTS} B=${BUCKETS} S=${S} WGI=${WGI}`);
+
+  const rng = makeRng(0x9111);
+  const { initBuf, initCounts, initOffsets, M } = buildL0ActiveSums(NPTS, BUCKETS, R, p, rng);
+  log('info', `built L0 active_sums: M=${M}`);
+
+  // Histogram peek
+  let maxC0 = 0, minC0 = NPTS, c0 = 0, smallC = 0;
+  for (let b = 0; b < initCounts.length; b++) {
+    if (initCounts[b] > maxC0) maxC0 = initCounts[b];
+    if (initCounts[b] < minC0) minC0 = initCounts[b];
+    if (initCounts[b] === 0) c0++;
+    if (initCounts[b] < 32) smallC++;
+  }
+  log('info', `bucket counts: min=${minC0} max=${maxC0} zero=${c0} small(<32)=${smallC}/${initCounts.length}`);
+
+  // Unused chunk slots read the pad pair at (M-2, M-1) of the level's input
+  // buffer and send their output to M-2 of the level's output buffer. Both
+  // ping-pong buffers start from the same image, so each holds a pad pair.
+  const padLIdx = M - 2;
+  const padRIdx = M - 1;
+  const discardIdx = M - 2;
+
+  // Every buffer created below is tracked and released in `finally`.
+  const owned: GPUBuffer[] = [];
+  const track = (buf: GPUBuffer): GPUBuffer => {
+    owned.push(buf);
+    return buf;
+  };
+  // Creates a tracked buffer sized to `data` and queues the upload.
+  const upload = (data: Uint32Array, usage: number): GPUBuffer => {
+    const buf = track(device.createBuffer({ size: data.byteLength, usage: usage | GPUBufferUsage.COPY_DST }));
+    device.queue.writeBuffer(buf, 0, data);
+    return buf;
+  };
+  // Binds four buffers to bindings 0..3 of `layout`.
+  const bind4 = (layout: GPUBindGroupLayout, b0: GPUBuffer, b1: GPUBuffer, b2: GPUBuffer, b3: GPUBuffer): GPUBindGroup =>
+    device.createBindGroup({
+      layout,
+      entries: [b0, b1, b2, b3].map((buffer, binding) => ({ binding, resource: { buffer } })),
+    });
+
+  try {
+    // Two ping-pong active_sums buffers, sized M, both seeded with the L0 image.
+    const bufA = track(makeSoABuf(device, M, true, true));
+    const bufB = track(makeSoABuf(device, M, true, true));
+    device.queue.writeBuffer(bufA, 0, initBuf);
+    device.queue.writeBuffer(bufB, 0, initBuf);
+
+    // Scratch buffers, sized for the largest (level-0) chunk count.
+    const maxL0Chunks = Math.ceil(NPTS / 2 / S) + 1;
+    const chainBuf = track(makeSoABuf(device, 2 * S * maxL0Chunks, false, false));
+    const tempOutBuf = track(makeSoABuf(device, S * maxL0Chunks, false, true));
+
+    // Compile all 4 pipelines.
+    const layoutMarshal = ioLayout4(device);
+    const layoutDisjoint = ioLayout4(device);
+    const layoutScatter = ioLayout4(device);
+    const layoutCarry = ioLayout4(device);
+    const marshalPipe = await compileOne(device, sm.gen_ba_marshal_pairs_bench_shader(WGI, S), `marshal-pairs-W${WGI}-S${S}`, layoutMarshal);
+    const disjointPipe = await compileOne(device, sm.gen_ba_pair_disjoint_tree_bench_shader(WGI, S), `disjoint-W${WGI}-S${S}`, layoutDisjoint);
+    const scatterPipe = await compileOne(device, sm.gen_ba_scatter_pairs_bench_shader(WGI, S), `scatter-pairs-W${WGI}-S${S}`, layoutScatter);
+    const carryPipe = await compileOne(device, sm.gen_ba_carry_copy_bench_shader(WGI), `carry-W${WGI}`, layoutCarry);
+    log('info', '4 pipelines compiled');
+
+    let counts = initCounts;
+    let offsets = initOffsets;
+    let curIn: GPUBuffer = bufA;
+    let curOut: GPUBuffer = bufB;
+    let totalPairAdds = 0;
+    let levelIdx = 0;
+    const levelTimings: LevelTiming[] = [];
+    // Placeholder for binding 1 of the disjoint kernel.
+    const dummy = track(device.createBuffer({ size: 16, usage: GPUBufferUsage.STORAGE }));
+    // Passes for every level, executed later in one submit. The plan uploads
+    // are queued with writeBuffer as each level is planned.
+    const allPasses: PassSpec[] = [];
+
+    const startTime = performance.now();
+
+    for (;;) {
+      let maxCount = 0;
+      for (let b = 0; b < counts.length; b++) if (counts[b] > maxCount) maxCount = counts[b];
+      if (maxCount <= 1) break;
+      if (levelIdx > 24) throw new Error('exceeded safety level cap');
+
+      const plan = buildLevelPlan(counts, offsets, S, padLIdx, padRIdx, discardIdx);
+      totalPairAdds += plan.totalPairs;
+      const T = plan.numChunks;
+      const numWgs = Math.ceil(T / WGI);
+      log('info', `L${levelIdx}: T=${T} pairs=${plan.totalPairs} carries=${plan.numCarries} maxCount=${maxCount}`);
+
+      const chunkPlanBuf = upload(plan.chunkPlan, GPUBufferUsage.STORAGE);
+      const scatterPlanBuf = upload(plan.scatterPlan, GPUBufferUsage.STORAGE);
+      const marshalParams = upload(new Uint32Array([T, M, 0, 0]), GPUBufferUsage.UNIFORM);
+      const disjointParams = upload(new Uint32Array([2 * S * T, T, 1, 0]), GPUBufferUsage.UNIFORM); // final_flag=1
+      const scatterParams = upload(new Uint32Array([T, M, 0, 0]), GPUBufferUsage.UNIFORM);
+
+      allPasses.push({ pipeline: marshalPipe, bind: bind4(layoutMarshal, chunkPlanBuf, curIn, chainBuf, marshalParams), numWgs });
+      allPasses.push({ pipeline: disjointPipe, bind: bind4(layoutDisjoint, chainBuf, dummy, tempOutBuf, disjointParams), numWgs });
+      allPasses.push({ pipeline: scatterPipe, bind: bind4(layoutScatter, scatterPlanBuf, tempOutBuf, curOut, scatterParams), numWgs });
+      if (plan.numCarries > 0) {
+        const carryPlanBuf = upload(plan.carryPlan, GPUBufferUsage.STORAGE);
+        const carryParams = upload(new Uint32Array([plan.numCarries, M, M, 0]), GPUBufferUsage.UNIFORM);
+        allPasses.push({
+          pipeline: carryPipe,
+          bind: bind4(layoutCarry, carryPlanBuf, curIn, curOut, carryParams),
+          numWgs: Math.ceil(plan.numCarries / WGI),
+        });
+      }
+      // Per-kernel times are not measured in the batched mode, hence the zeros.
+      levelTimings.push({
+        T, pairs: plan.totalPairs, carries: plan.numCarries,
+        marshal_ms: 0, disjoint_ms: 0, scatter_ms: 0, carry_ms: 0,
+      });
+      log('info', ` L${levelIdx} encoded (T=${T}, pairs=${plan.totalPairs})`);
+
+      counts = plan.newCounts;
+      offsets = plan.newOffsets;
+      [curIn, curOut] = [curOut, curIn];
+      levelIdx++;
+    }
+
+    // Single batched submit for ALL levels.
+    const totalWall = await timeBatched(device, allPasses);
+    log('info', `batched ${allPasses.length} passes in one submit: ${totalWall.toFixed(2)}ms`);
+
+    const wall = performance.now() - startTime; // includes plan-build + upload + GPU time
+    // NOTE(review): both buffers were seeded with initBuf, so these words are
+    // likely non-zero even if no kernel ran; this check cannot prove correctness.
+    const sanity = await readNonZero(device, curIn, 8);
+
+    const nsPerInpt = (totalWall * 1e6) / NPTS;
+    log(
+      sanity ? 'ok' : 'err',
+      `pipeline: ${levelIdx} levels, ${totalPairAdds} pair-adds, single-submit GPU wall=${totalWall.toFixed(2)}ms, total incl plan-upload=${wall.toFixed(2)}ms, ns/in-pt=${nsPerInpt.toFixed(2)}, sanity=${sanity ? 'OK' : 'FAIL'}`,
+    );
+
+    return {
+      s: S, wgi: WGI, pairs: NPTS, buckets: BUCKETS, levels: levelIdx,
+      total_pair_adds: totalPairAdds, total_wall_ms: totalWall, level_timings: levelTimings,
+      ns_per_inpt: nsPerInpt, sanity_ok: sanity,
+    };
+  } finally {
+    // Safe to release now: the submitted work and the readback have completed,
+    // or the run failed and nothing further will use these buffers.
+    for (const buf of owned) buf.destroy();
+  }
+}
+
+/**
+ * Reads the bench settings from the URL query string.
+ *
+ * `reps` defaults to 3 and must lie in (0, 50]. `n`, `buckets`, `s` and
+ * `wgi` override the module-level settings when supplied; a missing or
+ * empty parameter keeps the current value.
+ *
+ * @returns The effective settings.
+ * @throws Error if `reps` is out of range, or if a supplied override is not
+ *   a positive integer. Previously a bad override stored NaN, and
+ *   `buckets=0` reached the `% B` bucket assignment.
+ */
+function parseParams() {
+  const qp = new URLSearchParams(window.location.search);
+  const reps = parseInt(qp.get('reps') ?? '3', 10);
+  if (!Number.isFinite(reps) || reps <= 0 || reps > 50) throw new Error(`?reps must be in (0, 50]`);
+  // Returns the parsed override, or `current` when the parameter is absent/empty.
+  const positiveInt = (key: string, current: number): number => {
+    const raw = qp.get(key);
+    if (!raw) return current;
+    const value = parseInt(raw, 10);
+    if (!Number.isFinite(value) || value <= 0) throw new Error(`?${key} must be a positive integer (got "${raw}")`);
+    return value;
+  };
+  NPTS = positiveInt('n', NPTS);
+  BUCKETS = positiveInt('buckets', BUCKETS);
+  S = positiveInt('s', S);
+  WGI = positiveInt('wgi', WGI);
+  return { reps, n: NPTS, buckets: BUCKETS, s: S, wgi: WGI };
+}
+
+/**
+ * Page entry point: parses the query parameters, acquires a WebGPU device,
+ * runs the pipeline once and records the outcome in `benchState`.
+ *
+ * Any failure is logged and stored in `benchState.error` rather than
+ * rethrown.
+ */
+async function main() {
+  try {
+    const hasWebGpu = 'gpu' in navigator;
+    if (!hasWebGpu) throw new Error('navigator.gpu missing');
+
+    const parsed = parseParams();
+    benchState.params = parsed;
+    const { reps, n, buckets, s, wgi } = parsed;
+    log('info', `params: reps=${reps} n=${n} buckets=${buckets} s=${s} wgi=${wgi}`);
+    benchState.state = 'running';
+
+    const gpuDevice = await get_device();
+    log('info', 'WebGPU device acquired');
+
+    const fieldModulus = BN254_BASE_FIELD;
+    const montR = compute_misc_params(fieldModulus, 13).r;
+    const shaders = new ShaderManager(4, NPTS, BN254_CURVE_CONFIG, false);
+
+    const outcome = await runPipeline(gpuDevice, shaders, reps, montR, fieldModulus);
+    benchState.results.push(outcome);
+    resultsClient.postProgress({ kind: 'pipeline_done', ns_per_inpt: outcome.ns_per_inpt, sanity_ok: outcome.sanity_ok });
+
+    benchState.state = 'done';
+    log('ok', 'done');
+  } catch (e) {
+    // Keep the stack for Error instances so the posted report is debuggable.
+    const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e);
+    log('err', `FATAL: ${msg}`);
+    benchState.state = 'error';
+    benchState.error = msg;
+  }
+}
+
+// Kick off the bench. Whatever happens, post the final report afterwards;
+// a failure to post is deliberately ignored.
+(async () => {
+  try {
+    await main();
+  } catch (e) {
+    const msg = e instanceof Error ? e.message : String(e);
+    log('err', `unhandled: ${msg}`);
+    benchState.state = 'error';
+    benchState.error = msg;
+  } finally {
+    postFinal().catch(() => {});
+  }
+})();
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.html b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.html
new file mode 100644
index 000000000000..8eaa024c2f13
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.html
@@ -0,0 +1,22 @@
+
+
+
+
+ v3 — GPU planner + fused super-kernel MSM bucket-accumulate
+
+
+
+ v3 — GPU planner + fused super-kernel MSM bucket-accumulate
+ Query params: ?n=N&buckets=B&s=S&wgi=W&levels=L
+
+
+
+
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.ts
new file mode 100644
index 000000000000..4b420b00ee7d
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.ts
@@ -0,0 +1,538 @@
+///
+// bench-msm-tree-v3 — GPU-side planner + fused super-kernel.
+//
+// Per level (all GPU, encoded into one command list):
+// 1. Reset totals atomic counter to 0.
+// 2. Pre-pad chunk_plan + scatter_plan + carry_plan to safe values.
+// 3. Planner kernel: 1 thread per bucket -> writes chunk_plan,
+// scatter_plan, carry_plan, new_counts, new_offsets via atomic
+// offset reservation.
+// 4. Fused super-kernel: marshal + disjoint + scatter in one pass.
+// Reads chunk_plan + scatter_plan + active_sums_old; writes
+// active_sums_new.
+// 5. Carry kernel: copies odd-count carries from active_sums_old to
+// active_sums_new.
+// 6. Swap active_sums buffers, swap counts/offsets buffers.
+//
+// Over-dispatch L_MAX levels (default 8 for Poisson(λ=32) where max
+// bucket count is ~50-60, requiring log2(60) = 6 levels). Extra levels
+// with all-count-1 input are no-ops at the kernel level (planner
+// produces zero pairs; fused kernel dispatched with 0 threads via
+// host-side numWgs=0; carry kernel just copies the single-element-
+// per-bucket data forward).
+//
+// Single submit across all 3*L_MAX kernel dispatches. Zero host-GPU
+// round-trips between scalar-decompose and final readback.
+
+import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
+import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
+import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
+import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js';
+import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js';
+import { makeResultsClient } from './results_post.js';
+
+// Number of 4-word (4 x u32) groups each 8-word field element is split into
+// in the SoA buffers (see writeElem / makeSoABuf).
+const PG = 2;
+// Defaults for the bench settings.
+const DEFAULT_N = 1 << 17;
+const DEFAULT_BUCKETS = 1 << 12;
+const DEFAULT_S = 16;
+const DEFAULT_WGI = 64;
+// Number of tree levels dispatched (the file header calls this L_MAX).
+const DEFAULT_LEVELS = 8;
+
+// Effective settings, initialised from the defaults above.
+let NPTS = DEFAULT_N;
+let BUCKETS = DEFAULT_BUCKETS;
+let S = DEFAULT_S;
+let WGI = DEFAULT_WGI;
+let LEVELS = DEFAULT_LEVELS;
+
+/**
+ * Creates a deterministic 32-bit linear congruential generator.
+ *
+ * @param seed Initial state, truncated to an unsigned 32-bit value; a seed
+ *   that truncates to 0 is replaced by 1.
+ * @returns A function yielding the next unsigned 32-bit state on each call.
+ */
+function makeRng(seed: number): () => number {
+  const MULTIPLIER = 1664525;
+  const INCREMENT = 1013904223;
+  const truncated = seed >>> 0;
+  let current = truncated === 0 ? 1 : truncated;
+  return function next(): number {
+    current = (Math.imul(current, MULTIPLIER) + INCREMENT) >>> 0;
+    return current;
+  };
+}
+
+/**
+ * Draws a value v with 0 < v < p by rejection sampling.
+ *
+ * Each byte is taken from the TOP 8 bits of a 32-bit `rng()` output. The
+ * generator used in this file is a power-of-two-modulus LCG whose low byte
+ * repeats every 256 calls; sampling the low byte (as before) meant that
+ * consecutive 32-byte draws could only ever produce a handful of distinct
+ * values.
+ *
+ * @param p Exclusive upper bound; must be greater than 1.
+ * @param rng Source of unsigned 32-bit values.
+ * @returns A value in the open interval (0, p).
+ * @throws RangeError if p <= 1 (no valid value exists; the loop would never end).
+ */
+function randomBelow(p: bigint, rng: () => number): bigint {
+  if (p <= 1n) throw new RangeError('randomBelow: p must be greater than 1');
+  const bitlen = p.toString(2).length;
+  const byteLen = Math.ceil(bitlen / 8);
+  // Loop-invariant: trims the candidate to the bit length of p.
+  const mask = (1n << BigInt(bitlen)) - 1n;
+  for (;;) {
+    let v = 0n;
+    for (let i = 0; i < byteLen; i++) v = (v << 8n) | BigInt(rng() >>> 24);
+    v &= mask;
+    if (v > 0n && v < p) return v;
+  }
+}
+
+/**
+ * Splits a bigint into eight 32-bit words, least-significant word first.
+ * Bits above the low 256 are dropped.
+ */
+function bigintToPackedU32x8(v: bigint): Uint32Array {
+  const WORD_MASK = 0xffffffffn;
+  const words = new Uint32Array(8);
+  words.forEach((_, idx) => {
+    words[idx] = Number((v >> BigInt(32 * idx)) & WORD_MASK);
+  });
+  return words;
+}
+
+/**
+ * Builds the level-0 active_sums image plus per-bucket counts and offsets.
+ *
+ * N random (x, y) values (multiplied by R mod p) are assigned to random
+ * buckets and stored bucket-major; two extra elements at indices M-2 and
+ * M-1 form a pad pair whose x values are forced to differ.
+ *
+ * @returns The packed buffer image, per-bucket counts, prefix-sum offsets
+ *   (length B + 1) and the element count M = N + 2.
+ */
+function buildL0(N: number, B: number, R: bigint, p: bigint, rng: () => number) {
+  const M = N + 2;
+  const padL = M - 2;
+  const padR = M - 1;
+
+  // Draw M coordinate pairs; x is drawn before y for each element.
+  const xWords = new Uint32Array(8 * M);
+  const yWords = new Uint32Array(8 * M);
+  for (let i = 0; i < M; i++) {
+    const xMont = (randomBelow(p, rng) * R) % p;
+    const yMont = (randomBelow(p, rng) * R) % p;
+    xWords.set(bigintToPackedU32x8(xMont), 8 * i);
+    yWords.set(bigintToPackedU32x8(yMont), 8 * i);
+  }
+  // If the pad pair agrees in its lowest x word, flip one bit so the x's differ.
+  if (xWords[8 * padL] === xWords[8 * padR]) {
+    xWords[8 * padR] ^= 1;
+  }
+
+  // Bucket assignment from the upper halves of two generator outputs.
+  const bucketOf = new Uint32Array(N);
+  const counts = new Uint32Array(B);
+  for (let i = 0; i < N; i++) {
+    const hi = (rng() >>> 16) & 0xffff;
+    const lo = (rng() >>> 16) & 0xffff;
+    const b = (hi * 0x10000 + lo) % B;
+    bucketOf[i] = b;
+    counts[b]++;
+  }
+  const offsets = new Uint32Array(B + 1);
+  counts.forEach((c, b) => {
+    offsets[b + 1] = offsets[b] + c;
+  });
+
+  const buf = new Uint32Array(2 * PG * M * 4);
+  // Copies one 8-word coordinate into its PG groups of 4 words for `slot`.
+  const storeCoord = (plane: number, slot: number, words: Uint32Array, src: number) => {
+    for (let g = 0; g < PG; g++) {
+      const dst = ((plane * PG + g) * M + slot) * 4;
+      buf.set(words.subarray(src + 4 * g, src + 4 * g + 4), dst);
+    }
+  };
+  // Stores element `srcIdx` (x into plane 0, y into plane 1) at `slot`.
+  const storePoint = (slot: number, srcIdx: number) => {
+    storeCoord(0, slot, xWords, 8 * srcIdx);
+    storeCoord(1, slot, yWords, 8 * srcIdx);
+  };
+
+  const filled = new Uint32Array(B);
+  for (let i = 0; i < N; i++) {
+    const b = bucketOf[i];
+    storePoint(offsets[b] + filled[b]++, i);
+  }
+  storePoint(padL, padL);
+  storePoint(padR, padR);
+
+  return { initBuf: buf, initCounts: counts, initOffsets: offsets, M };
+}
+
+/**
+ * Allocates a STORAGE buffer sized for M elements in the SoA layout:
+ * 2 planes x PG groups x M elements x 4 words x 4 bytes.
+ *
+ * @param copyDst Also allow the buffer to be a copy/write destination.
+ * @param copySrc Also allow the buffer to be a copy source.
+ */
+function makeSoABuf(device: GPUDevice, M: number, copyDst: boolean, copySrc: boolean): GPUBuffer {
+  const extraUsages = [
+    copyDst ? GPUBufferUsage.COPY_DST : 0,
+    copySrc ? GPUBufferUsage.COPY_SRC : 0,
+  ];
+  const usage = extraUsages.reduce((acc, flag) => acc | flag, GPUBufferUsage.STORAGE);
+  const size = 2 * PG * M * 4 * 4;
+  return device.createBuffer({ size, usage });
+}
+
+// Result record for one pipeline run; an array of these is kept in
+// BenchState.results and posted by postFinal().
+interface RunResult {
+ s: number;
+ wgi: number;
+ pairs: number;
+ buckets: number;
+ levels_run: number;
+ gpu_wall_ms: number;
+ ns_per_inpt: number;
+ sanity_ok: boolean;
+}
+
+// Mutable page state: lifecycle phase, parsed parameters, collected results,
+// the fatal error message (if any) and every line passed to log().
+interface BenchState {
+ state: 'boot' | 'running' | 'done' | 'error';
+ params: { n: number; buckets: number; s: number; wgi: number; levels: number } | null;
+ results: RunResult[];
+ error: string | null;
+ log: string[];
+}
+
+// Shared page state, also published as window.__bench (presumably so an
+// external harness can poll it — confirm against the runner).
+const benchState: BenchState = { state: 'boot', params: null, results: [], error: null, log: [] };
+(window as unknown as { __bench: BenchState }).__bench = benchState;
+
+// Results client for this page; its run id is published as window.__runId.
+const resultsClient = makeResultsClient({ page: 'bench-msm-tree-v3' });
+(window as unknown as { __runId: string }).__runId = resultsClient.runId;
+
+/**
+ * Posts the final bench report (state, parameters, results, error, log and
+ * basic browser information) through the results client.
+ */
+async function postFinal(): Promise<void> {
+  await resultsClient.postResults({
+    state: benchState.state,
+    params: benchState.params,
+    results: benchState.results,
+    error: benchState.error,
+    log: benchState.log,
+    userAgent: navigator.userAgent,
+    hardwareConcurrency: navigator.hardwareConcurrency,
+  });
+}
+
+// On-page log container. getElementById may return null at runtime even
+// though the cast says otherwise; log() guards against that.
+const $log = document.getElementById('log') as HTMLDivElement;
+
+/**
+ * Records a message in three places: the on-page log (when present),
+ * `benchState.log`, and the browser console.
+ *
+ * A missing #log element no longer makes logging throw, so errors raised
+ * before or without the DOM container still reach the report.
+ *
+ * @param level Severity; also used as the CSS class, except 'info' which has none.
+ * @param msg Message text (inserted as text, never as HTML).
+ */
+function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) {
+  const row = document.createElement('div');
+  row.className = level === 'info' ? '' : level;
+  row.textContent = msg;
+  // Optional call: skip the DOM append when the container is absent.
+  $log?.appendChild(row);
+  benchState.log.push(`[${level}] ${msg}`);
+  console.log(`[bench-msm-tree-v3] ${msg}`);
+}
+
+/**
+ * Compiles one WGSL module and builds a compute pipeline for its `main`
+ * entry point with a single bind-group layout.
+ *
+ * Compiler errors are logged to the console and the page log; other
+ * messages go to `console.warn` only.
+ *
+ * @param key Label used in diagnostics.
+ * @throws Error if the compiler reports any error (first four are included).
+ */
+async function compileOne(device: GPUDevice, code: string, key: string, layout: GPUBindGroupLayout): Promise<GPUComputePipeline> {
+  const module = device.createShaderModule({ code });
+  const info = await module.getCompilationInfo();
+  const errLines: string[] = [];
+  for (const m of info.messages) {
+    const line = `[shader ${key}] ${m.type}: ${m.message} (line ${m.lineNum}, col ${m.linePos})`;
+    if (m.type === 'error') {
+      console.error(line);
+      log('err', line);
+      errLines.push(line);
+    } else {
+      console.warn(line);
+    }
+  }
+  if (errLines.length > 0) throw new Error(`WGSL compile failed for ${key}: ${errLines.slice(0, 4).join(' | ')}`);
+  return device.createComputePipelineAsync({
+    layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }),
+    compute: { module, entryPoint: 'main' },
+  });
+}
+
+/**
+ * Copies the first `u32Count` 32-bit words of `buf` to the host and reports
+ * whether any of them is non-zero.
+ *
+ * The staging buffer is now released even when the copy or mapping fails,
+ * and the mapped range is scanned in place instead of being copied first.
+ *
+ * @param buf Source buffer; it is used as a copy source.
+ * @returns true if at least one of the inspected words is non-zero.
+ */
+async function readNonZero(device: GPUDevice, buf: GPUBuffer, u32Count: number): Promise<boolean> {
+  const bytes = u32Count * 4;
+  const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST });
+  try {
+    const enc = device.createCommandEncoder();
+    enc.copyBufferToBuffer(buf, 0, staging, 0, bytes);
+    device.queue.submit([enc.finish()]);
+    await device.queue.onSubmittedWorkDone();
+    await staging.mapAsync(GPUMapMode.READ);
+    // Scan before unmap(): unmapping detaches the mapped ArrayBuffer.
+    const anyNonZero = new Uint32Array(staging.getMappedRange()).some(word => word !== 0);
+    staging.unmap();
+    return anyNonZero;
+  } finally {
+    staging.destroy();
+  }
+}
+
+/**
+ * Runs the v3 multi-level bucket-accumulate pipeline once and reports GPU wall time.
+ *
+ * For each of LEVELS levels three compute passes are encoded (planner, fused
+ * pair-sum, carry copy). Counts/offsets and the active-sums buffers are
+ * ping-ponged between levels. All levels are recorded into one command encoder
+ * that is submitted once; the reported time is measured around that submit.
+ *
+ * Per-level state resets (the totals counter and the pad contents of the plan
+ * buffers) are recorded INTO the command encoder via clearBuffer /
+ * copyBufferToBuffer. They must not be issued with queue.writeBuffer inside the
+ * level loop: writeBuffer is a queue operation, so every such write would run
+ * before the single submit and only level 0 would see freshly reset state.
+ *
+ * @param device WebGPU device used for all buffers, pipelines and submits.
+ * @param sm Shader generator for the planner, fused and carry kernels.
+ * @param R Passed through to buildL0 (the caller supplies compute_misc_params(p, 13).r).
+ * @param p Field modulus, passed through to buildL0.
+ * @returns Run parameters plus gpu_wall_ms, ns_per_inpt and the sanity flag.
+ * @throws Error if a shader fails to compile or the L0 layout has fewer than 2 elements.
+ */
+async function runPipeline(device: GPUDevice, sm: ShaderManager, R: bigint, p: bigint): Promise {
+  log('info', `=== N=${NPTS} B=${BUCKETS} S=${S} WGI=${WGI} LEVELS=${LEVELS}`);
+
+  const rng = makeRng(0xc711);
+  const { initBuf, initCounts, initOffsets, M } = buildL0(NPTS, BUCKETS, R, p, rng);
+  log('info', `L0 active_sums: M=${M}, B=${BUCKETS}`);
+  // Pad plan entries reference elements M-2 and M-1; with M < 2 those indices
+  // would wrap around when stored into a Uint32Array.
+  if (M < 2) throw new Error(`active_sums needs at least 2 elements for pad slots, got M=${M}`);
+
+  // Histogram peek
+  let maxC = 0, minC = NPTS, smallC = 0;
+  for (let b = 0; b < initCounts.length; b++) {
+    if (initCounts[b] > maxC) maxC = initCounts[b];
+    if (initCounts[b] < minC) minC = initCounts[b];
+    if (initCounts[b] < 32) smallC++;
+  }
+  log('info', `bucket counts: min=${minC} max=${maxC} small(<32)=${smallC}/${BUCKETS}`);
+
+  // Plan-buffer sizing — must accommodate L0 max chunks.
+  const MAX_CHUNKS = Math.ceil(NPTS / 2 / S) + 16;
+  const MAX_PAIR_SLOTS = MAX_CHUNKS * S;
+  const MAX_CARRIES = BUCKETS;
+
+  // Host simulates the bin-packing iteration to compute the right
+  // dispatch size per level. The GPU planner does the same work to
+  // fill plan buffers; this host loop only computes sizes (T_chunks,
+  // T_carries) so the host can dispatch the fused + carry kernels at
+  // the correct size per level, avoiding pad-chunk waste.
+  //
+  // Without this, the fused kernel runs MAX_CHUNKS=4160 threads per
+  // level regardless of actual work. At L5 with only ~117 real chunks
+  // that's ~3979 pad-chunks each running a full fr_inv_by_a + S mont
+  // muls -- dominates wall time.
+  const perLevelTChunks: number[] = [];
+  const perLevelTCarries: number[] = [];
+  {
+    let cur = new Uint32Array(initCounts);
+    for (let lv = 0; lv < LEVELS; lv++) {
+      let totalPairs = 0;
+      let totalCarries = 0;
+      const next = new Uint32Array(cur.length);
+      for (let b = 0; b < cur.length; b++) {
+        const n = cur[b];
+        const pairs = (n / 2) | 0; // floor(n / 2) pair sums produced by this bucket
+        const carry = n & 1; // odd element carried over unchanged
+        totalPairs += pairs;
+        totalCarries += carry;
+        next[b] = pairs + carry;
+      }
+      perLevelTChunks.push(Math.max(1, Math.ceil(totalPairs / S)));
+      perLevelTCarries.push(Math.max(1, totalCarries));
+      cur = next;
+    }
+    log('info', `host-simulated per-level T_chunks: ${perLevelTChunks.join(', ')}`);
+    log('info', `host-simulated per-level T_carries: ${perLevelTCarries.join(', ')}`);
+  }
+
+  const mkStorage = (bytes: number, copyDst = true, copySrc = false): GPUBuffer => {
+    let usage = GPUBufferUsage.STORAGE;
+    if (copyDst) usage |= GPUBufferUsage.COPY_DST;
+    if (copySrc) usage |= GPUBufferUsage.COPY_SRC;
+    return device.createBuffer({ size: bytes, usage });
+  };
+
+  // Ping-pong active_sums.
+  const bufA = makeSoABuf(device, M, true, true);
+  const bufB = makeSoABuf(device, M, true, true);
+  device.queue.writeBuffer(bufA, 0, initBuf);
+  device.queue.writeBuffer(bufB, 0, initBuf); // mirror initial for pad-pair availability
+
+  // Plan buffers (reused per level).
+  const chunkPlanBuf = mkStorage(2 * MAX_PAIR_SLOTS * 4);
+  const scatterPlanBuf = mkStorage(MAX_PAIR_SLOTS * 4);
+  const carryPlanBuf = mkStorage(2 * MAX_CARRIES * 4);
+
+  // Per-level counts/offsets buffers (ping-pong).
+  const countsA = mkStorage(BUCKETS * 4);
+  const countsB = mkStorage(BUCKETS * 4);
+  const offsetsA = mkStorage((BUCKETS + 1) * 4);
+  const offsetsB = mkStorage((BUCKETS + 1) * 4);
+  device.queue.writeBuffer(countsA, 0, initCounts);
+  device.queue.writeBuffer(offsetsA, 0, initOffsets);
+
+  // Totals atomic counter [pair_off_accum, carry_off_accum, new_off_accum, _]
+  const totalsBuf = device.createBuffer({
+    size: 16,
+    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
+  });
+
+  // Pre-padded chunk_plan / scatter_plan / carry_plan templates.
+  // Pad slots all point to safe values:
+  //   chunk_plan pad pair = (M-2, M-1) — known distinct-x in active_sums
+  //   scatter_plan pad dst = M-2 — discard target (within active_sums_new)
+  //   carry_plan pad src = M-2, dst = M-2 — no-op self-copy of pad slot
+  const padChunkPlan = new Uint32Array(2 * MAX_PAIR_SLOTS);
+  const padScatterPlan = new Uint32Array(MAX_PAIR_SLOTS);
+  const padCarryPlan = new Uint32Array(2 * MAX_CARRIES);
+  for (let i = 0; i < MAX_PAIR_SLOTS; i++) {
+    padChunkPlan[2 * i + 0] = M - 2;
+    padChunkPlan[2 * i + 1] = M - 1;
+    padScatterPlan[i] = M - 2;
+  }
+  for (let i = 0; i < MAX_CARRIES; i++) {
+    padCarryPlan[2 * i + 0] = M - 2;
+    padCarryPlan[2 * i + 1] = M - 2;
+  }
+
+  // GPU-resident copies of the pad templates. Each level restores the plan
+  // buffers from these with an encoded copy, so the restore is ordered with
+  // that level's planner pass inside the single command buffer.
+  const mkTemplate = (data: Uint32Array): GPUBuffer => {
+    const template = device.createBuffer({
+      size: data.byteLength,
+      usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
+    });
+    device.queue.writeBuffer(template, 0, data);
+    return template;
+  };
+  const padChunkTemplate = mkTemplate(padChunkPlan);
+  const padScatterTemplate = mkTemplate(padScatterPlan);
+  const padCarryTemplate = mkTemplate(padCarryPlan);
+
+  // Per-level params (reused).
+  const plannerParams = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST });
+  const fusedParams: GPUBuffer[] = [];
+  const carryParams: GPUBuffer[] = [];
+  for (let i = 0; i < LEVELS; i++) {
+    fusedParams.push(device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }));
+    carryParams.push(device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }));
+  }
+  device.queue.writeBuffer(plannerParams, 0, new Uint32Array([BUCKETS, S, 0, 0]));
+
+  // Layouts.
+  const plannerLayout = device.createBindGroupLayout({
+    entries: [
+      { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+      { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+      { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+      { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+      { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+      { binding: 5, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+      { binding: 6, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+      { binding: 7, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+      { binding: 8, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } },
+    ],
+  });
+  const fusedLayout = device.createBindGroupLayout({
+    entries: [
+      { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+      { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+      { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+      { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+      { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } },
+    ],
+  });
+  const carryLayout = device.createBindGroupLayout({
+    entries: [
+      { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+      { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+      { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+      { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } },
+    ],
+  });
+
+  const plannerPipe = await compileOne(device, sm.gen_ba_planner_bench_shader(WGI, S, 64), `planner-W${WGI}-S${S}`, plannerLayout);
+  const fusedPipe = await compileOne(device, sm.gen_ba_fused_super_bench_shader(WGI, S), `fused-W${WGI}-S${S}`, fusedLayout);
+  const carryPipe = await compileOne(device, sm.gen_ba_carry_copy_bench_shader(WGI), `carry-W${WGI}`, carryLayout);
+  log('info', '3 pipelines compiled');
+
+  // Encode all level passes into one command encoder.
+  // Per level k:
+  //   - clear totalsBuf to zero (encoded clearBuffer)
+  //   - restore plan buffers from the pad templates (encoded copies)
+  //   - dispatch planner
+  //   - dispatch fused (numWgs = ceil(T_chunks[k] / WGI), host-simulated size)
+  //   - dispatch carry (numWgs = ceil(T_carries[k] / WGI), host-simulated size)
+  //   - swap counts/offsets via bind group selection on next level
+
+  const enc = device.createCommandEncoder();
+  let curCountsIn: GPUBuffer = countsA;
+  let curCountsOut: GPUBuffer = countsB;
+  let curOffsetsIn: GPUBuffer = offsetsA;
+  let curOffsetsOut: GPUBuffer = offsetsB;
+  let curActiveIn: GPUBuffer = bufA;
+  let curActiveOut: GPUBuffer = bufB;
+
+  const numWgsPlanner = Math.ceil(BUCKETS / WGI);
+  const numWgsFusedPerLevel = perLevelTChunks.map(t => Math.ceil(t / WGI));
+  const numWgsCarryPerLevel = perLevelTCarries.map(t => Math.ceil(t / WGI));
+  log('info', `numWgs: planner=${numWgsPlanner}, fused=${numWgsFusedPerLevel.join(',')}, carry=${numWgsCarryPerLevel.join(',')}`);
+
+  // Per-level params with the right-sized T from the host bin-pack simulator.
+  // These buffers are distinct per level and never rewritten, so a plain
+  // queue write before the submit is sufficient.
+  for (let lv = 0; lv < LEVELS; lv++) {
+    device.queue.writeBuffer(fusedParams[lv], 0, new Uint32Array([perLevelTChunks[lv], M, M, 0]));
+    device.queue.writeBuffer(carryParams[lv], 0, new Uint32Array([perLevelTCarries[lv], M, M, 0]));
+  }
+
+  // Bind groups built per level (to swap counts/offsets/active buffers).
+  for (let lv = 0; lv < LEVELS; lv++) {
+    // Reset the totals atomic counter before this level's planner. Encoded
+    // in the command buffer so it runs between level lv-1 and level lv.
+    enc.clearBuffer(totalsBuf, 0, 16);
+    // Re-pad plan buffers (planner overwrites only real entries; the
+    // pad regions get re-padded between levels to clean any leftover
+    // real entries from the prior level).
+    enc.copyBufferToBuffer(padChunkTemplate, 0, chunkPlanBuf, 0, padChunkPlan.byteLength);
+    enc.copyBufferToBuffer(padScatterTemplate, 0, scatterPlanBuf, 0, padScatterPlan.byteLength);
+    enc.copyBufferToBuffer(padCarryTemplate, 0, carryPlanBuf, 0, padCarryPlan.byteLength);
+
+    const plannerBind = device.createBindGroup({
+      layout: plannerLayout,
+      entries: [
+        { binding: 0, resource: { buffer: curCountsIn } },
+        { binding: 1, resource: { buffer: curOffsetsIn } },
+        { binding: 2, resource: { buffer: chunkPlanBuf } },
+        { binding: 3, resource: { buffer: scatterPlanBuf } },
+        { binding: 4, resource: { buffer: carryPlanBuf } },
+        { binding: 5, resource: { buffer: totalsBuf } },
+        { binding: 6, resource: { buffer: curCountsOut } },
+        { binding: 7, resource: { buffer: curOffsetsOut } },
+        { binding: 8, resource: { buffer: plannerParams } },
+      ],
+    });
+    const fusedBind = device.createBindGroup({
+      layout: fusedLayout,
+      entries: [
+        { binding: 0, resource: { buffer: chunkPlanBuf } },
+        { binding: 1, resource: { buffer: scatterPlanBuf } },
+        { binding: 2, resource: { buffer: curActiveIn } },
+        { binding: 3, resource: { buffer: curActiveOut } },
+        { binding: 4, resource: { buffer: fusedParams[lv] } },
+      ],
+    });
+    const carryBind = device.createBindGroup({
+      layout: carryLayout,
+      entries: [
+        { binding: 0, resource: { buffer: carryPlanBuf } },
+        { binding: 1, resource: { buffer: curActiveIn } },
+        { binding: 2, resource: { buffer: curActiveOut } },
+        { binding: 3, resource: { buffer: carryParams[lv] } },
+      ],
+    });
+
+    // Encode the 3 passes for this level.
+    {
+      const pass = enc.beginComputePass();
+      pass.setPipeline(plannerPipe);
+      pass.setBindGroup(0, plannerBind);
+      pass.dispatchWorkgroups(numWgsPlanner, 1, 1);
+      pass.end();
+    }
+    {
+      const pass = enc.beginComputePass();
+      pass.setPipeline(fusedPipe);
+      pass.setBindGroup(0, fusedBind);
+      pass.dispatchWorkgroups(numWgsFusedPerLevel[lv], 1, 1);
+      pass.end();
+    }
+    {
+      const pass = enc.beginComputePass();
+      pass.setPipeline(carryPipe);
+      pass.setBindGroup(0, carryBind);
+      pass.dispatchWorkgroups(numWgsCarryPerLevel[lv], 1, 1);
+      pass.end();
+    }
+
+    // Swap for next level.
+    [curCountsIn, curCountsOut] = [curCountsOut, curCountsIn];
+    [curOffsetsIn, curOffsetsOut] = [curOffsetsOut, curOffsetsIn];
+    [curActiveIn, curActiveOut] = [curActiveOut, curActiveIn];
+  }
+
+  // Single submit + single await for the entire bucket-accumulate.
+  // The measured wall time includes the per-level clear/copy commands.
+  const t0 = performance.now();
+  device.queue.submit([enc.finish()]);
+  await device.queue.onSubmittedWorkDone();
+  const gpuWall = performance.now() - t0;
+
+  // Sanity: at least one element of the final active_sums must be non-zero.
+  // After the last swap curActiveIn is the buffer the final level wrote to.
+  const sanity = await readNonZero(device, curActiveIn, 8);
+  const nsPerInpt = (gpuWall * 1e6) / NPTS;
+
+  log(
+    sanity ? 'ok' : 'err',
+    `v3 pipeline: ${LEVELS} levels over-dispatched, single submit GPU wall=${gpuWall.toFixed(2)}ms, ns/in-pt=${nsPerInpt.toFixed(2)}, sanity=${sanity ? 'OK' : 'FAIL'}`,
+  );
+
+  // Cleanup
+  bufA.destroy(); bufB.destroy();
+  chunkPlanBuf.destroy(); scatterPlanBuf.destroy(); carryPlanBuf.destroy();
+  padChunkTemplate.destroy(); padScatterTemplate.destroy(); padCarryTemplate.destroy();
+  countsA.destroy(); countsB.destroy(); offsetsA.destroy(); offsetsB.destroy();
+  totalsBuf.destroy();
+  plannerParams.destroy();
+  for (const b of fusedParams) b.destroy();
+  for (const b of carryParams) b.destroy();
+
+  return {
+    s: S, wgi: WGI, pairs: NPTS, buckets: BUCKETS,
+    levels_run: LEVELS, gpu_wall_ms: gpuWall, ns_per_inpt: nsPerInpt, sanity_ok: sanity,
+  };
+}
+
+/**
+ * Applies URL query-string overrides (?n, ?buckets, ?s, ?wgi, ?levels) to the
+ * module-level bench parameters and returns the effective values.
+ *
+ * A parameter that is absent or empty leaves the current value untouched. A
+ * parameter that is present must parse (base 10) to a safe integer: at least 1
+ * for n, buckets, s and wgi, at least 0 for levels. Anything else would flow
+ * into buffer sizes and dispatch counts as NaN, zero or a negative number.
+ *
+ * @returns The effective { n, buckets, s, wgi, levels }.
+ * @throws Error naming the offending query parameter and its raw value.
+ */
+function parseParams() {
+  const qp = new URLSearchParams(window.location.search);
+  const readInt = (name: string, current: number, min: number): number => {
+    const raw = qp.get(name);
+    if (!raw) return current;
+    const value = Number.parseInt(raw, 10);
+    if (!Number.isSafeInteger(value) || value < min) {
+      throw new Error(`query param ${name}=${raw} must be an integer >= ${min}`);
+    }
+    return value;
+  };
+  NPTS = readInt('n', NPTS, 1);
+  BUCKETS = readInt('buckets', BUCKETS, 1);
+  S = readInt('s', S, 1);
+  WGI = readInt('wgi', WGI, 1);
+  LEVELS = readInt('levels', LEVELS, 0);
+  return { n: NPTS, buckets: BUCKETS, s: S, wgi: WGI, levels: LEVELS };
+}
+
+/**
+ * Page entry point: parses the query parameters, acquires a WebGPU device,
+ * runs the pipeline once and records the outcome in benchState.
+ *
+ * Any failure is caught here, logged as FATAL (message plus stack for Error
+ * instances) and stored in benchState.error; the function itself does not
+ * rethrow.
+ */
+async function main() {
+  try {
+    if (!('gpu' in navigator)) throw new Error('navigator.gpu missing');
+    const params = parseParams();
+    benchState.params = params;
+    const { n, buckets, s, wgi, levels } = params;
+    log('info', `params: n=${n} buckets=${buckets} s=${s} wgi=${wgi} levels=${levels}`);
+    benchState.state = 'running';
+    const device = await get_device();
+    log('info', 'WebGPU device acquired');
+    const fieldModulus = BN254_BASE_FIELD;
+    const { r: rParam } = compute_misc_params(fieldModulus, 13);
+    const shaders = new ShaderManager(4, NPTS, BN254_CURVE_CONFIG, false);
+    const outcome = await runPipeline(device, shaders, rParam, fieldModulus);
+    benchState.results.push(outcome);
+    resultsClient.postProgress({
+      kind: 'pipeline_done',
+      ns_per_inpt: outcome.ns_per_inpt,
+      sanity_ok: outcome.sanity_ok,
+    });
+    benchState.state = 'done';
+    log('ok', 'done');
+  } catch (e) {
+    let detail: string;
+    if (e instanceof Error) {
+      detail = `${e.message}\n${e.stack}`;
+    } else {
+      detail = String(e);
+    }
+    log('err', `FATAL: ${detail}`);
+    benchState.state = 'error';
+    benchState.error = detail;
+  }
+}
+
+// Kick off the bench. main() handles its own errors, so the .catch below is a
+// last-resort net; postFinal() is attempted regardless of the outcome and its
+// own failure is deliberately ignored.
+main()
+ .catch(e => { const msg = e instanceof Error ? e.message : String(e); log('err', `unhandled: ${msg}`); benchState.state = 'error'; benchState.error = msg; })
+ .finally(() => { postFinal().catch(() => {}); });
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.html b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.html
new file mode 100644
index 000000000000..f64e6d3ff15a
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.html
@@ -0,0 +1,24 @@
+
+
+
+
+ MSM bucket-accumulate multi-level pair-tree bench (WebGPU)
+
+
+
+ MSM bucket-accumulate multi-level pair-tree bench (WebGPU)
+ Query params: ?reps=R&n=N&buckets=B&s=S&wgi=W&mode=uniform|skewed&disp=D
+
+
+
+
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts
new file mode 100644
index 000000000000..f8c05dfa79c3
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts
@@ -0,0 +1,790 @@
+///
+// bench-msm-tree — end-to-end MSM bucket-accumulate benchmark for the
+// multi-level pair-tree pipeline:
+// 1. marshal-l0 (CSR + chunk_plan + point_pool -> strided chain_buf)
+// 2. tree-disjoint level 0 (chain_buf -> level-1 input layout, in-place via ping-pong)
+// 3. tree-disjoint level 1 (continues in ping-pong)
+// ...
+// level L-1 (final): tree-disjoint with `final` flag -> simple strided output
+// 4. tail kernel (small buckets count<2*S -> one sum each)
+//
+// Reports per-stage and combined ns/in-pt for the full bucket-accumulate
+// over N points distributed across B buckets.
+//
+// Modes:
+// ?mode=uniform : every bucket has exactly 2*S = 32 points (clean
+// multi-level test, no tail).
+// ?mode=skewed : Poisson-distributed via uniform random scalar
+// assignment. Main pair-tree handles buckets with
+// count >= 2*S; tail kernel handles the rest.
+
+import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
+import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
+import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
+import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js';
+import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js';
+import { makeResultsClient } from './results_post.js';
+
+// Number of 4-word groups per packed field element: buildPointPool splits each
+// 8-word value into PG groups of 4 u32.
+const PG = 2;
+const DEFAULT_N = 1 << 17; // 131072 points
+const DEFAULT_BUCKETS = 1 << 12; // 4096 buckets -> uniform avg 32 = 2*S
+const DEFAULT_S = 16;
+const DEFAULT_WGI = 64;
+const DEFAULT_DISP = 4; // dispatch amortisation per timed sample
+const DEFAULT_MODE = 'uniform' as const;
+
+// Effective run parameters; they start at the defaults above and are mutable
+// module-level state read by the functions below.
+let NPTS = DEFAULT_N;
+let BUCKETS = DEFAULT_BUCKETS;
+let S = DEFAULT_S;
+let WGI = DEFAULT_WGI;
+let DISP = DEFAULT_DISP;
+let MODE: 'uniform' | 'skewed' = DEFAULT_MODE;
+
+/**
+ * Creates a deterministic 32-bit linear congruential generator.
+ *
+ * The seed is coerced to an unsigned 32-bit integer; a seed of 0 is replaced
+ * by 1. Each call of the returned function advances the state by
+ * state * 1664525 + 1013904223 (mod 2^32) and returns the new state.
+ */
+function makeRng(seed: number): () => number {
+  const MULTIPLIER = 1664525;
+  const INCREMENT = 1013904223;
+  let current = seed >>> 0;
+  if (current === 0) current = 1;
+  return function next(): number {
+    current = (Math.imul(current, MULTIPLIER) + INCREMENT) >>> 0;
+    return current;
+  };
+}
+
+/**
+ * Draws a value v with 0 < v < p by rejection sampling.
+ *
+ * Candidates are assembled from ceil(bitlen(p) / 8) bytes and masked down to
+ * bitlen(p) bits; candidates that are 0 or >= p are discarded.
+ *
+ * Each byte is taken from the HIGH 8 bits of a 32-bit draw. The low 8 bits of
+ * a power-of-two-modulus LCG (such as makeRng) repeat every 256 draws, so
+ * using them would make the candidate stream repeat after 256 / byteLen
+ * candidates (only 8 distinct 32-byte candidates for a 254-bit p).
+ *
+ * @param p Exclusive upper bound; must be greater than 1.
+ * @param rng Source of unsigned 32-bit integers (bits 24..31 are consumed).
+ * @throws RangeError if p <= 1, for which no value in (0, p) exists.
+ */
+function randomBelow(p: bigint, rng: () => number): bigint {
+  if (p <= 1n) throw new RangeError(`randomBelow requires p > 1, got ${p}`);
+  const bitlen = p.toString(2).length;
+  const byteLen = Math.ceil(bitlen / 8);
+  const mask = (1n << BigInt(bitlen)) - 1n;
+  for (;;) {
+    let v = 0n;
+    for (let i = 0; i < byteLen; i++) v = (v << 8n) | BigInt((rng() >>> 24) & 0xff);
+    v &= mask;
+    if (v > 0n && v < p) return v;
+  }
+}
+
+/**
+ * Splits the low 256 bits of `v` into eight 32-bit words, least-significant
+ * word first. Bits above 2^256 are dropped.
+ */
+function bigintToPackedU32x8(v: bigint): Uint32Array {
+  const WORD_MASK = 0xffffffffn;
+  return Uint32Array.from({ length: 8 }, (_unused, word) => Number((v >> BigInt(32 * word)) & WORD_MASK));
+}
+
+/**
+ * Returns the element at index floor(n / 2) of the ascending-sorted samples
+ * (the upper middle value when n is even), or NaN for an empty list.
+ * The input array is not modified.
+ */
+function median(xs: number[]): number {
+  const n = xs.length;
+  if (n === 0) return NaN;
+  const ordered = [...xs].sort((lhs, rhs) => lhs - rhs);
+  return ordered[n >> 1];
+}
+
+/**
+ * Fills a structure-of-arrays pool with `poolSize` random (x, y) pairs.
+ *
+ * For each element, x then y is drawn with randomBelow(p, rng) and multiplied
+ * by R modulo p. Each coordinate is packed into 8 u32 words and stored as PG
+ * groups of 4 words; group g of plane c (0 = x, 1 = y) for element e starts at
+ * word index ((c * PG + g) * poolSize + e) * 4.
+ *
+ * @returns A Uint32Array of 2 * PG * poolSize * 4 words.
+ */
+function buildPointPool(poolSize: number, R: bigint, p: bigint, rng: () => number): Uint32Array {
+  const out = new Uint32Array(2 * PG * poolSize * 4);
+  const storeCoord = (plane: number, elem: number, words: Uint32Array): void => {
+    for (let group = 0; group < PG; group++) {
+      const dst = ((plane * PG + group) * poolSize + elem) * 4;
+      out.set(words.subarray(4 * group, 4 * group + 4), dst);
+    }
+  };
+  for (let elem = 0; elem < poolSize; elem++) {
+    // Draw order (x first, then y) determines the generated data.
+    const xScaled = (randomBelow(p, rng) * R) % p;
+    const yScaled = (randomBelow(p, rng) * R) % p;
+    storeCoord(0, elem, bigintToPackedU32x8(xScaled));
+    storeCoord(1, elem, bigintToPackedU32x8(yScaled));
+  }
+  return out;
+}
+
+/**
+ * Bucket assignment in compressed-sparse-row form, as produced by
+ * buildUniformCSR / buildSkewedCSR: bucket b owns the entries
+ * csrIndices[offsets[b] .. offsets[b + 1]) and counts[b] is their number.
+ */
+interface CSR {
+ csrIndices: Uint32Array; // 1-based: index 0 reserved (decoy/unused)
+ offsets: Uint32Array; // length B + 1, prefix sums of counts
+ counts: Uint32Array; // length B, points per bucket
+}
+
+/**
+ * Builds a CSR in which each of the B buckets holds exactly `perBucket`
+ * consecutive points; bucket b gets the 1-based point indices
+ * b * perBucket + 1 .. (b + 1) * perBucket.
+ *
+ * @param N Total number of points; must equal B * perBucket.
+ * @param B Number of buckets (non-negative integer).
+ * @param perBucket Points per bucket (non-negative integer).
+ * @throws Error if B or perBucket is not a non-negative integer, or if
+ *   N !== B * perBucket. A fractional perBucket (e.g. N=10, B=4) would
+ *   otherwise satisfy the product check and yield truncated counts with
+ *   overlapping index ranges.
+ */
+function buildUniformCSR(N: number, B: number, perBucket: number): CSR {
+  if (!Number.isInteger(B) || B < 0 || !Number.isInteger(perBucket) || perBucket < 0) {
+    throw new Error(`uniform mode requires integer B and N/B, got B=${B}, perBucket=${perBucket}`);
+  }
+  if (N !== B * perBucket) {
+    throw new Error(`uniform mode requires N=${N} = B*${perBucket}=${B * perBucket}`);
+  }
+  const counts = new Uint32Array(B);
+  const offsets = new Uint32Array(B + 1);
+  const csrIndices = new Uint32Array(N);
+  for (let b = 0; b < B; b++) {
+    counts[b] = perBucket;
+    offsets[b + 1] = offsets[b] + perBucket;
+    for (let i = 0; i < perBucket; i++) {
+      csrIndices[offsets[b] + i] = b * perBucket + i + 1; // 1-based
+    }
+  }
+  return { csrIndices, offsets, counts };
+}
+
+/**
+ * Builds a CSR by assigning each of N points to a pseudo-random bucket.
+ *
+ * The bucket is (hi * 2^16 + lo) % B, where hi and lo are the upper 16 bits
+ * of two consecutive rng draws. The low bits of the LCG are avoided on
+ * purpose: they have short periods (the low 12 bits cycle every 4096 calls,
+ * which equals B in the default config), so a direct modulo would give every
+ * bucket exactly N/B points. The 32-bit value is composed by multiplication
+ * rather than shifts to stay clear of signed-i32 bitwise semantics.
+ *
+ * Within a bucket, points keep ascending order; stored indices are 1-based.
+ */
+function buildSkewedCSR(N: number, B: number, rng: () => number): CSR {
+  const bucketOf = new Uint32Array(N);
+  const counts = new Uint32Array(B);
+  const upperHalf = (): number => (rng() >>> 16) & 0xffff;
+  for (let pt = 0; pt < N; pt++) {
+    const hi = upperHalf();
+    const lo = upperHalf();
+    const target = (hi * 0x10000 + lo) % B;
+    bucketOf[pt] = target;
+    counts[target]++;
+  }
+  const offsets = new Uint32Array(B + 1);
+  counts.forEach((count, b) => {
+    offsets[b + 1] = offsets[b] + count;
+  });
+  // Next free slot of each bucket, starting at the bucket's first slot.
+  const nextSlot = offsets.slice(0, B);
+  const csrIndices = new Uint32Array(N);
+  bucketOf.forEach((b, pt) => {
+    csrIndices[nextSlot[b]++] = pt + 1; // 1-based
+  });
+  return { csrIndices, offsets, counts };
+}
+
+/**
+ * Splits every bucket into full blocks of 2*S points for the level-0 marshal +
+ * pair-tree pass, plus at most one tail of (count mod 2*S) points for the
+ * tail kernel.
+ *
+ * A bucket with count >= 2*S contributes floor(count / (2*S)) main chunks; its
+ * remainder, if any, becomes one tail entry. A bucket with count < 2*S goes
+ * entirely to the tail.
+ *
+ * NOTE: for v1 the several main-path chunks of one bucket are *not* re-folded
+ * into a single per-bucket sum. That matters for correctness, but this bench
+ * only measures dispatch wall-clock.
+ *
+ * @returns chunkPlan: T pairs (bucket, first CSR slot of the block);
+ *   tailPlan: TT triples (bucket, first CSR slot of the tail, tail length);
+ *   mainPoints / tailPoints: number of points covered by each plan.
+ */
+function buildChunkAndTailPlans(
+  offsets: Uint32Array,
+  counts: Uint32Array,
+  S: number,
+): { chunkPlan: Uint32Array; T: number; tailPlan: Uint32Array; TT: number; mainPoints: number; tailPoints: number } {
+  const blockPts = 2 * S;
+  const split = (count: number): { full: number; rest: number } => {
+    const full = Math.floor(count / blockPts);
+    return { full, rest: count - full * blockPts };
+  };
+
+  // Pass 1: size the plans.
+  let T = 0;
+  let TT = 0;
+  let mainPoints = 0;
+  let tailPoints = 0;
+  for (const count of counts) {
+    const { full, rest } = split(count);
+    T += full;
+    mainPoints += full * blockPts;
+    if (rest > 0) {
+      TT += 1;
+      tailPoints += rest;
+    }
+  }
+
+  // Pass 2: emit the entries in bucket order.
+  const chunkPlan = new Uint32Array(2 * T);
+  const tailPlan = new Uint32Array(3 * TT);
+  let chunkCursor = 0;
+  let tailCursor = 0;
+  counts.forEach((count, b) => {
+    const { full, rest } = split(count);
+    const base = offsets[b];
+    for (let k = 0; k < full; k++) {
+      chunkPlan[chunkCursor++] = b;
+      chunkPlan[chunkCursor++] = base + k * blockPts;
+    }
+    if (rest > 0) {
+      tailPlan[tailCursor++] = b;
+      tailPlan[tailCursor++] = base + full * blockPts;
+      tailPlan[tailCursor++] = rest;
+    }
+  });
+  return { chunkPlan, T, tailPlan, TT, mainPoints, tailPoints };
+}
+
+/** Wall-clock statistics over the timed submissions of timeDispatch, in milliseconds. */
+interface KernelTiming {
+ median_ms: number; // median(samples_ms)
+ min_ms: number; // fastest sample
+ max_ms: number; // slowest sample
+ samples_ms: number[]; // one entry per timed submission, in run order
+}
+
+/**
+ * Compiles WGSL source into a compute pipeline (entry point `main`) that uses
+ * a single bind group layout.
+ *
+ * Compilation errors are written to console.error and the page log; other
+ * messages go to console.warn. If any error was reported, the function throws
+ * with up to the first four error lines joined by ' | '.
+ *
+ * @param cacheKey Label used in diagnostics and in the thrown error message.
+ */
+async function compile(
+  device: GPUDevice,
+  code: string,
+  cacheKey: string,
+  layout: GPUBindGroupLayout,
+): Promise {
+  const shaderModule = device.createShaderModule({ code });
+  const compilation = await shaderModule.getCompilationInfo();
+  const failures: string[] = [];
+  for (const entry of compilation.messages) {
+    const rendered = `[shader ${cacheKey}] ${entry.type}: ${entry.message} (line ${entry.lineNum}, col ${entry.linePos})`;
+    if (entry.type !== 'error') {
+      console.warn(rendered);
+      continue;
+    }
+    console.error(rendered);
+    log('err', rendered);
+    failures.push(rendered);
+  }
+  if (failures.length > 0) {
+    throw new Error(`WGSL compile failed for ${cacheKey}: ${failures.slice(0, 4).join(' | ')}`);
+  }
+  const pipelineLayout = device.createPipelineLayout({ bindGroupLayouts: [layout] });
+  return device.createComputePipelineAsync({
+    layout: pipelineLayout,
+    compute: { module: shaderModule, entryPoint: 'main' },
+  });
+}
+
+/**
+ * Bind group layout for the marshal-l0 kernel: bindings 0-2 are read-only
+ * storage buffers, binding 3 is a writable storage buffer and binding 4 is a
+ * uniform buffer, all visible to the compute stage.
+ */
+function marshalLayout(device: GPUDevice): GPUBindGroupLayout {
+  // Position in this list is the binding number.
+  const bufferKinds = ['read-only-storage', 'read-only-storage', 'read-only-storage', 'storage', 'uniform'] as const;
+  return device.createBindGroupLayout({
+    entries: bufferKinds.map((type, binding) => ({
+      binding,
+      visibility: GPUShaderStage.COMPUTE,
+      buffer: { type },
+    })),
+  });
+}
+
+/**
+ * Bind group layout for the pair-tree kernel: bindings 0-1 are read-only
+ * storage buffers, binding 2 is a writable storage buffer and binding 3 is a
+ * uniform buffer, all visible to the compute stage.
+ */
+function treeKernelLayout(device: GPUDevice): GPUBindGroupLayout {
+  // Position in this list is the binding number.
+  const bufferKinds = ['read-only-storage', 'read-only-storage', 'storage', 'uniform'] as const;
+  return device.createBindGroupLayout({
+    entries: bufferKinds.map((type, binding) => ({
+      binding,
+      visibility: GPUShaderStage.COMPUTE,
+      buffer: { type },
+    })),
+  });
+}
+
+/**
+ * Bind group layout for the tail kernel: bindings 0-2 are read-only storage
+ * buffers, binding 3 is a writable storage buffer and binding 4 is a uniform
+ * buffer, all visible to the compute stage.
+ */
+function tailLayout(device: GPUDevice): GPUBindGroupLayout {
+  // Position in this list is the binding number.
+  const bufferKinds = ['read-only-storage', 'read-only-storage', 'read-only-storage', 'storage', 'uniform'] as const;
+  return device.createBindGroupLayout({
+    entries: bufferKinds.map((type, binding) => ({
+      binding,
+      visibility: GPUShaderStage.COMPUTE,
+      buffer: { type },
+    })),
+  });
+}
+
+/**
+ * Reads back the first `u32Count` 32-bit words of `buf` through a temporary
+ * MAP_READ staging buffer and returns true if any word is non-zero.
+ *
+ * `buf` needs COPY_SRC usage. The staging buffer is released before returning.
+ */
+async function readNonZero(device: GPUDevice, buf: GPUBuffer, u32Count: number): Promise {
+  const byteLength = u32Count * 4;
+  const stagingBuf = device.createBuffer({
+    size: byteLength,
+    usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
+  });
+  const encoder = device.createCommandEncoder();
+  encoder.copyBufferToBuffer(buf, 0, stagingBuf, 0, byteLength);
+  device.queue.submit([encoder.finish()]);
+  await device.queue.onSubmittedWorkDone();
+  await stagingBuf.mapAsync(GPUMapMode.READ);
+  // Copy out of the mapping so the data stays valid after unmap().
+  const snapshot = new Uint32Array(stagingBuf.getMappedRange().slice(0));
+  stagingBuf.unmap();
+  stagingBuf.destroy();
+  let anyNonZero = false;
+  for (const word of snapshot) {
+    if (word !== 0) {
+      anyNonZero = true;
+      break;
+    }
+  }
+  return anyNonZero;
+}
+
+/**
+ * Measures the wall-clock time of a repeated compute dispatch.
+ *
+ * One untimed warm-up submission is followed by `reps` timed submissions. Each
+ * submission encodes `passes` identical compute passes dispatching `numWgs`
+ * workgroups. A sample covers finish + submit + onSubmittedWorkDone, taken
+ * with performance.now(); command recording happens before the clock starts.
+ *
+ * @returns Median, min, max and the raw samples in milliseconds.
+ */
+async function timeDispatch(
+  device: GPUDevice,
+  pipeline: GPUComputePipeline,
+  bind: GPUBindGroup,
+  numWgs: number,
+  reps: number,
+  passes: number,
+): Promise {
+  // Records `passes` back-to-back compute passes into a fresh encoder.
+  const record = () => {
+    const encoder = device.createCommandEncoder();
+    for (let i = 0; i < passes; i++) {
+      const computePass = encoder.beginComputePass();
+      computePass.setPipeline(pipeline);
+      computePass.setBindGroup(0, bind);
+      computePass.dispatchWorkgroups(numWgs, 1, 1);
+      computePass.end();
+    }
+    return encoder;
+  };
+
+  // warmup
+  const warmup = record();
+  device.queue.submit([warmup.finish()]);
+  await device.queue.onSubmittedWorkDone();
+
+  const samples: number[] = [];
+  for (let rep = 0; rep < reps; rep++) {
+    const recorded = record();
+    const started = performance.now();
+    device.queue.submit([recorded.finish()]);
+    await device.queue.onSubmittedWorkDone();
+    samples.push(performance.now() - started);
+  }
+  return {
+    median_ms: median(samples),
+    min_ms: Math.min(...samples),
+    max_ms: Math.max(...samples),
+    samples_ms: samples,
+  };
+}
+
+/**
+ * One pipeline run as stored in benchState.results and sent by postFinal.
+ * NOTE(review): the per-field notes below are inferred from the field names;
+ * the code that fills them is outside this view, so confirm against runPipeline.
+ */
+interface RunResult {
+ // Run configuration (presumably echoes S, WGI, DISP, NPTS, BUCKETS, MODE).
+ s: number;
+ wgi: number;
+ disp: number;
+ pairs: number;
+ buckets: number;
+ mode: string;
+ // Plan sizes (presumably T, TT, mainPoints, tailPoints from buildChunkAndTailPlans).
+ T_main: number;
+ T_tail: number;
+ main_points: number;
+ tail_points: number;
+ levels: number;
+ // Per-stage timings in milliseconds.
+ marshal_ms: number;
+ level_ms: number[];
+ tail_ms: number;
+ total_ms: number;
+ // Timings normalised per input point, in nanoseconds.
+ marshal_ns_per_inpt: number;
+ pair_tree_ns_per_inpt: number;
+ tail_ns_per_inpt: number;
+ combined_ns_per_inpt: number;
+ sanity_ok: boolean;
+}
+
+/** Page-level bench status, exposed on window.__bench and posted by postFinal. */
+interface BenchState {
+ state: 'boot' | 'running' | 'done' | 'error'; // starts at 'boot'
+ params: { reps: number; n: number; buckets: number; s: number; wgi: number; disp: number; mode: string } | null;
+ results: RunResult[];
+ error: string | null;
+ log: string[]; // every log() line, formatted as "[level] message"
+}
+
+// Single mutable status object for this page.
+const benchState: BenchState = {
+ state: 'boot',
+ params: null,
+ results: [],
+ error: null,
+ log: [],
+};
+// Publish the live object on window.__bench.
+(window as unknown as { __bench: BenchState }).__bench = benchState;
+
+// Client used by postFinal to upload results; its run id is published on window.__runId.
+const resultsClient = makeResultsClient({ page: 'bench-msm-tree' });
+(window as unknown as { __runId: string }).__runId = resultsClient.runId;
+
+/**
+ * Uploads the final bench state (status, params, results, error, log lines)
+ * together with the browser's user agent and hardwareConcurrency.
+ */
+async function postFinal(): Promise {
+  const { state, params, results, error, log: logLines } = benchState;
+  const { userAgent, hardwareConcurrency } = navigator;
+  await resultsClient.postResults({
+    state,
+    params,
+    results,
+    error,
+    log: logLines,
+    userAgent,
+    hardwareConcurrency,
+  });
+}
+
+// Container element that receives one child <div> per log line.
+const $log = document.getElementById('log') as HTMLDivElement;
+/**
+ * Appends a line to the on-page log, records it in benchState.log as
+ * "[level] msg" and mirrors it to the console.
+ *
+ * The line's CSS class is the level name for 'ok', 'err' and 'warn', and
+ * empty for 'info'.
+ */
+function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) {
+  let cssClass = '';
+  if (level === 'ok' || level === 'err' || level === 'warn') {
+    cssClass = level;
+  }
+  const row = document.createElement('div');
+  row.className = cssClass;
+  row.textContent = msg;
+  $log.appendChild(row);
+  benchState.log.push(`[${level}] ${msg}`);
+  console.log(`[bench-msm-tree] ${msg}`);
+}
+
+async function runPipeline(
+ device: GPUDevice,
+ sm: ShaderManager,
+ reps: number,
+ R: bigint,
+ p: bigint,
+): Promise {
+ log('info', `=== mode=${MODE} N=${NPTS} B=${BUCKETS} S=${S} WGI=${WGI} DISP=${DISP}`);
+
+ const rng = makeRng(0x4711);
+ const poolSize = NPTS + 1; // index 0 reserved (1-based)
+ const poolU32 = buildPointPool(poolSize, R, p, rng);
+
+ const csr: CSR =
+ MODE === 'uniform'
+ ? buildUniformCSR(NPTS, BUCKETS, NPTS / BUCKETS)
+ : buildSkewedCSR(NPTS, BUCKETS, rng);
+
+ const offsets = csr.offsets;
+ const counts = csr.counts;
+ const { chunkPlan, T, tailPlan, TT, mainPoints, tailPoints } = buildChunkAndTailPlans(offsets, counts, S);
+ log(
+ 'info',
+ `plan: main T=${T} chunks (${mainPoints} pts) | tail TT=${TT} threads (${tailPoints} pts) | dropped=${NPTS - mainPoints - tailPoints}`,
+ );
+ if (T === 0 && TT === 0) throw new Error('plan is empty');
+
+ // Determine number of pair-tree levels. Each level halves T and
+ // every level produces T*S outputs. We stop when T*S equals the
+ // distinct-bucket count contributing to the main path (= one sum
+ // per bucket). Iterating further would start pairing across
+ // buckets, which is incorrect.
+ let bMain = 0;
+ for (let b = 0; b < counts.length; b++) {
+ if (counts[b] >= 2 * S) bMain++;
+ }
+ if (bMain === 0) throw new Error('no buckets in main path');
+ const stopT = Math.max(1, Math.ceil(bMain / S));
+ const levels: number[] = [];
+ for (let t = T; t >= stopT; t = Math.floor(t / 2)) {
+ levels.push(t);
+ if (t === stopT) break;
+ if (levels.length > 24) throw new Error('too many tree levels');
+ }
+ log(
+ 'info',
+ `pair-tree levels: ${levels.length} (T sequence: ${levels.join(' -> ')}, stopT=${stopT}, bMain=${bMain})`,
+ );
+
+ // Buffers.
+ const mkSb = (size: number, copyDst: boolean, copySrc: boolean): GPUBuffer => {
+ let usage = GPUBufferUsage.STORAGE;
+ if (copyDst) usage |= GPUBufferUsage.COPY_DST;
+ if (copySrc) usage |= GPUBufferUsage.COPY_SRC;
+ return device.createBuffer({ size, usage });
+ };
+
+ const poolBuf = mkSb(poolU32.byteLength, true, false);
+ device.queue.writeBuffer(poolBuf, 0, poolU32);
+ let csrBuf: GPUBuffer | null = null;
+ let chunkBuf: GPUBuffer | null = null;
+ let tailPlanBuf: GPUBuffer | null = null;
+ if (T > 0) {
+ csrBuf = mkSb(csr.csrIndices.byteLength, true, false);
+ device.queue.writeBuffer(csrBuf, 0, csr.csrIndices);
+ chunkBuf = mkSb(chunkPlan.byteLength, true, false);
+ device.queue.writeBuffer(chunkBuf, 0, chunkPlan);
+ } else if (TT > 0) {
+ csrBuf = mkSb(csr.csrIndices.byteLength, true, false);
+ device.queue.writeBuffer(csrBuf, 0, csr.csrIndices);
+ }
+ if (TT > 0) {
+ tailPlanBuf = mkSb(tailPlan.byteLength, true, false);
+ device.queue.writeBuffer(tailPlanBuf, 0, tailPlan);
+ }
+
+ // Ping-pong buffers for the pair-tree. Each plane needs at most
+ // 2*S*T_0 vec4 (level-0 input size). Output of level k is sized
+ // S*T_k vec4 per plane, which is half. Use two buffers of the same
+ // size and ping-pong.
+ let bufA: GPUBuffer | null = null;
+ let bufB: GPUBuffer | null = null;
+ if (T > 0) {
+ const planeBytes = 2 * PG * (2 * S * T) * 4 * 4; // 2 planes (P.x, P.y) * PG vec4 * (2*S*T elems) * 4 u32/vec4 * 4 B
+ bufA = mkSb(planeBytes, false, true);
+ bufB = mkSb(planeBytes, false, true);
+ }
+
+ // Bucket-sums buffer (tail output): 2 planes (P.x, P.y), PG vec4 per
+ // bucket. Pre-zero implicitly (GPU buffers start zeroed in WebGPU).
+ const bucketSumsBytes = 2 * PG * BUCKETS * 4 * 4;
+ const bucketSumsBuf = mkSb(bucketSumsBytes, false, true);
+
+ const paramsBytes = 16;
+ const marshalParams = device.createBuffer({ size: paramsBytes, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST });
+ const levelParams: GPUBuffer[] = [];
+ for (let i = 0; i < levels.length; i++) {
+ levelParams.push(device.createBuffer({ size: paramsBytes, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }));
+ }
+ const tailParams = device.createBuffer({ size: paramsBytes, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST });
+
+ // Compile + bind.
+ let marshalPipeline: GPUComputePipeline | null = null;
+ let marshalBind: GPUBindGroup | null = null;
+ let marshalWgs = 0;
+ if (T > 0 && bufA !== null && csrBuf !== null && chunkBuf !== null) {
+ const code = sm.gen_ba_marshal_tree_l0_bench_shader(WGI, S);
+ log('info', `marshal-l0 shader: ${code.length} chars`);
+ const mL = marshalLayout(device);
+ marshalPipeline = await compile(device, code, `marshal-l0-W${WGI}-S${S}`, mL);
+ marshalBind = device.createBindGroup({
+ layout: mL,
+ entries: [
+ { binding: 0, resource: { buffer: csrBuf } },
+ { binding: 1, resource: { buffer: chunkBuf } },
+ { binding: 2, resource: { buffer: poolBuf } },
+ { binding: 3, resource: { buffer: bufA } },
+ { binding: 4, resource: { buffer: marshalParams } },
+ ],
+ });
+ device.queue.writeBuffer(marshalParams, 0, new Uint32Array([T, poolSize, 0, 0]));
+ marshalWgs = Math.ceil(T / WGI);
+ }
+
+ // Tree kernel: compile once, bind per-level with the appropriate
+ // (input, output, params) trio.
+ let treePipeline: GPUComputePipeline | null = null;
+ const treeBinds: GPUBindGroup[] = [];
+ const treeNumWgs: number[] = [];
+ if (T > 0 && bufA !== null && bufB !== null) {
+ const code = sm.gen_ba_pair_disjoint_tree_bench_shader(WGI, S);
+ log('info', `tree-disjoint shader: ${code.length} chars`);
+ const tL = treeKernelLayout(device);
+ treePipeline = await compile(device, code, `tree-disjoint-W${WGI}-S${S}`, tL);
+ const dummy = device.createBuffer({ size: 16, usage: GPUBufferUsage.STORAGE });
+ let curIn = bufA;
+ let curOut = bufB;
+ for (let lv = 0; lv < levels.length; lv++) {
+ const T_lv = levels[lv];
+ const N_in_lv = 2 * S * T_lv;
+ const isFinal = lv === levels.length - 1 ? 1 : 0;
+ device.queue.writeBuffer(levelParams[lv], 0, new Uint32Array([N_in_lv, T_lv, isFinal, 0]));
+ treeBinds.push(
+ device.createBindGroup({
+ layout: tL,
+ entries: [
+ { binding: 0, resource: { buffer: curIn } },
+ { binding: 1, resource: { buffer: dummy } },
+ { binding: 2, resource: { buffer: curOut } },
+ { binding: 3, resource: { buffer: levelParams[lv] } },
+ ],
+ }),
+ );
+ treeNumWgs.push(Math.ceil(T_lv / WGI));
+ const swap = curIn;
+ curIn = curOut;
+ curOut = swap;
+ }
+ }
+
+ // Tail pipeline.
+ let tailPipeline: GPUComputePipeline | null = null;
+ let tailBind: GPUBindGroup | null = null;
+ let tailWgs = 0;
+ if (TT > 0 && csrBuf !== null && tailPlanBuf !== null) {
+ const code = sm.gen_ba_tail_reduce_bench_shader(WGI, S);
+ log('info', `tail shader: ${code.length} chars`);
+ const tL = tailLayout(device);
+ tailPipeline = await compile(device, code, `tail-W${WGI}-S${S}`, tL);
+ tailBind = device.createBindGroup({
+ layout: tL,
+ entries: [
+ { binding: 0, resource: { buffer: csrBuf } },
+ { binding: 1, resource: { buffer: tailPlanBuf } },
+ { binding: 2, resource: { buffer: poolBuf } },
+ { binding: 3, resource: { buffer: bucketSumsBuf } },
+ { binding: 4, resource: { buffer: tailParams } },
+ ],
+ });
+ device.queue.writeBuffer(tailParams, 0, new Uint32Array([TT, poolSize, BUCKETS, 0]));
+ tailWgs = Math.ceil(TT / WGI);
+ }
+
+ // Warmup once.
+ if (marshalPipeline && marshalBind) {
+ const enc = device.createCommandEncoder();
+ const pass = enc.beginComputePass();
+ pass.setPipeline(marshalPipeline);
+ pass.setBindGroup(0, marshalBind);
+ pass.dispatchWorkgroups(marshalWgs, 1, 1);
+ pass.end();
+ device.queue.submit([enc.finish()]);
+ await device.queue.onSubmittedWorkDone();
+ }
+ if (treePipeline) {
+ for (let lv = 0; lv < treeBinds.length; lv++) {
+ const enc = device.createCommandEncoder();
+ const pass = enc.beginComputePass();
+ pass.setPipeline(treePipeline);
+ pass.setBindGroup(0, treeBinds[lv]);
+ pass.dispatchWorkgroups(treeNumWgs[lv], 1, 1);
+ pass.end();
+ device.queue.submit([enc.finish()]);
+ await device.queue.onSubmittedWorkDone();
+ }
+ }
+ if (tailPipeline && tailBind) {
+ const enc = device.createCommandEncoder();
+ const pass = enc.beginComputePass();
+ pass.setPipeline(tailPipeline);
+ pass.setBindGroup(0, tailBind);
+ pass.dispatchWorkgroups(tailWgs, 1, 1);
+ pass.end();
+ device.queue.submit([enc.finish()]);
+ await device.queue.onSubmittedWorkDone();
+ }
+
+ // Time each stage separately, DISP back-to-back per sample.
+ let marshalTiming: KernelTiming | null = null;
+ if (marshalPipeline && marshalBind) {
+ marshalTiming = await timeDispatch(device, marshalPipeline, marshalBind, marshalWgs, reps, DISP);
+ }
+ const levelTimings: KernelTiming[] = [];
+ if (treePipeline) {
+ for (let lv = 0; lv < treeBinds.length; lv++) {
+ const t = await timeDispatch(device, treePipeline, treeBinds[lv], treeNumWgs[lv], reps, DISP);
+ levelTimings.push(t);
+ }
+ }
+ let tailTiming: KernelTiming | null = null;
+ if (tailPipeline && tailBind) {
+ tailTiming = await timeDispatch(device, tailPipeline, tailBind, tailWgs, reps, DISP);
+ }
+
+ // Sanity: at least one of the output buffers must have nonzero data.
+ let sanity = false;
+ if (treePipeline && bufA && bufB) {
+ const finalBuf = treeBinds.length % 2 === 0 ? bufA : bufB;
+ sanity = sanity || (await readNonZero(device, finalBuf, 8));
+ }
+ if (tailPipeline) {
+ sanity = sanity || (await readNonZero(device, bucketSumsBuf, 8));
+ }
+
+ // Compute per-stage ns/in-pt (normalised to total points fed to that
+ // stage; DISP-amortised wall clock).
+ const totalInPts = mainPoints + tailPoints;
+ const marshalMed = marshalTiming?.median_ms ?? 0;
+ const marshalNs = (marshalMed * 1e6) / (mainPoints * DISP);
+ const treeTotalMed = levelTimings.reduce((acc, t) => acc + t.median_ms, 0);
+ const treeNs = (treeTotalMed * 1e6) / (mainPoints * DISP);
+ const tailMed = tailTiming?.median_ms ?? 0;
+ const tailNs = TT > 0 ? (tailMed * 1e6) / (tailPoints * DISP) : 0;
+ const combinedTotal = marshalMed + treeTotalMed + tailMed;
+ const combinedNs = (combinedTotal * 1e6) / (totalInPts * DISP);
+
+ log(
+ sanity ? 'ok' : 'err',
+ `marshal=${marshalNs.toFixed(2)}ns/pt pair_tree=${treeNs.toFixed(2)}ns/pt tail=${tailNs.toFixed(2)}ns/pt | combined=${combinedNs.toFixed(2)}ns/in-pt | sanity=${sanity ? 'OK' : 'FAIL'}`,
+ );
+ for (let lv = 0; lv < levelTimings.length; lv++) {
+ log('info', ` level ${lv} T=${levels[lv]}: median=${levelTimings[lv].median_ms.toFixed(3)}ms min=${levelTimings[lv].min_ms.toFixed(3)}ms`);
+ }
+
+ // Cleanup
+ poolBuf.destroy();
+ csrBuf?.destroy();
+ chunkBuf?.destroy();
+ tailPlanBuf?.destroy();
+ bufA?.destroy();
+ bufB?.destroy();
+ bucketSumsBuf.destroy();
+ marshalParams.destroy();
+ for (const b of levelParams) b.destroy();
+ tailParams.destroy();
+
+ return {
+ s: S,
+ wgi: WGI,
+ disp: DISP,
+ pairs: NPTS,
+ buckets: BUCKETS,
+ mode: MODE,
+ T_main: T,
+ T_tail: TT,
+ main_points: mainPoints,
+ tail_points: tailPoints,
+ levels: levelTimings.length,
+ marshal_ms: marshalMed,
+ level_ms: levelTimings.map(t => t.median_ms),
+ tail_ms: tailMed,
+ total_ms: combinedTotal,
+ marshal_ns_per_inpt: marshalNs,
+ pair_tree_ns_per_inpt: treeNs,
+ tail_ns_per_inpt: tailNs,
+ combined_ns_per_inpt: combinedNs,
+ sanity_ok: sanity,
+ };
+}
+
+/**
+ * Reads the bench configuration from the page URL's query string and stores
+ * the accepted values in the module-level tunables (NPTS, BUCKETS, S, WGI,
+ * DISP, MODE). `reps` is the only value that is returned without being stored.
+ *
+ * Parameters are processed in a fixed order (reps, n, buckets, s, wgi, disp,
+ * mode); a parameter that fails validation throws, leaving earlier parameters
+ * already applied.
+ *
+ * @returns the effective configuration after applying the query string.
+ * @throws Error when a supplied parameter is outside its accepted range.
+ */
+function parseParams() {
+  const qp = new URLSearchParams(window.location.search);
+
+  // Parses ?key as a base-10 integer in (0, max]. An absent or empty key
+  // yields undefined so the caller keeps the current default.
+  const boundedInt = (key: string, max: number, message: string): number | undefined => {
+    const raw = qp.get(key);
+    if (!raw) return undefined;
+    const parsed = parseInt(raw, 10);
+    if (!Number.isFinite(parsed) || parsed <= 0 || parsed > max) throw new Error(message);
+    return parsed;
+  };
+
+  // reps always goes through validation, falling back to '5' only when absent.
+  const reps = parseInt(qp.get('reps') ?? '5', 10);
+  if (!Number.isFinite(reps) || reps <= 0 || reps > 50) throw new Error(`?reps must be in (0, 50]`);
+
+  const n = boundedInt('n', 1 << 20, `?n must be in (0, 2^20]`);
+  if (n !== undefined) NPTS = n;
+
+  const buckets = boundedInt('buckets', 1 << 18, `?buckets must be in (0, 2^18]`);
+  if (buckets !== undefined) BUCKETS = buckets;
+
+  const s = boundedInt('s', 256, `?s must be in (0, 256]`);
+  if (s !== undefined) S = s;
+
+  const wgi = boundedInt('wgi', 1024, `?wgi must be in (0, 1024]`);
+  if (wgi !== undefined) WGI = wgi;
+
+  const disp = boundedInt('disp', 64, `?disp must be in (0, 64]`);
+  if (disp !== undefined) DISP = disp;
+
+  const mode = qp.get('mode');
+  if (mode) {
+    if (mode !== 'uniform' && mode !== 'skewed') throw new Error(`?mode must be uniform or skewed`);
+    MODE = mode;
+  }
+
+  return { reps, n: NPTS, buckets: BUCKETS, s: S, wgi: WGI, disp: DISP, mode: MODE };
+}
+
+/**
+ * Page entry point: parses the query parameters, acquires a WebGPU device,
+ * runs the pipeline once and records the result in `benchState`.
+ *
+ * Every failure inside the body is caught here: it is logged and stored in
+ * `benchState.error`, and `benchState.state` becomes 'error'.
+ */
+async function main() {
+  try {
+    if (!('gpu' in navigator)) throw new Error('navigator.gpu missing — WebGPU not available');
+
+    const params = parseParams();
+    benchState.params = params;
+    const { reps, n, buckets, s, wgi, disp, mode } = params;
+    log('info', `params: reps=${reps} n=${n} buckets=${buckets} s=${s} wgi=${wgi} disp=${disp} mode=${mode}`);
+
+    benchState.state = 'running';
+    const device = await get_device();
+    log('info', 'WebGPU device acquired');
+
+    // `r` from compute_misc_params is handed to runPipeline together with the base-field modulus.
+    const modulus = BN254_BASE_FIELD;
+    const { r: montR } = compute_misc_params(modulus, 13);
+
+    const shaders = new ShaderManager(4, NPTS, BN254_CURVE_CONFIG, false);
+
+    const outcome = await runPipeline(device, shaders, reps, montR, modulus);
+    benchState.results.push(outcome);
+    resultsClient.postProgress({
+      kind: 'pipeline_done',
+      mode: outcome.mode,
+      combined_ns_per_inpt: outcome.combined_ns_per_inpt,
+      sanity_ok: outcome.sanity_ok,
+    });
+
+    benchState.state = 'done';
+    log('ok', 'pipeline complete');
+  } catch (err) {
+    let detail: string;
+    if (err instanceof Error) {
+      detail = `${err.message}\n${err.stack}`;
+    } else {
+      detail = String(err);
+    }
+    log('err', `FATAL: ${detail}`);
+    benchState.state = 'error';
+    benchState.error = detail;
+  }
+}
+
+// Kick off the bench. main() handles its own failures, so the catch below is a
+// last-resort net; the final results are posted whichever way the run ended,
+// and a failure to post is deliberately ignored.
+main()
+  .catch(reason => {
+    let text: string;
+    if (reason instanceof Error) {
+      text = reason.message;
+    } else {
+      text = String(reason);
+    }
+    log('err', `unhandled: ${text}`);
+    benchState.state = 'error';
+    benchState.error = text;
+  })
+  .finally(() => {
+    postFinal().catch(() => {});
+  });
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-planner.html b/barretenberg/ts/dev/msm-webgpu/bench-planner.html
new file mode 100644
index 000000000000..dd773cd4d504
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-planner.html
@@ -0,0 +1,22 @@
+
+
+
+
+ Standalone GPU planner microbench (WebGPU)
+
+
+
+ Standalone GPU planner microbench
+ Query params: ?buckets=B&lambda=L&s=S&tpb=T&per=P&disp=D&reps=R&validate=1
+
+
+
+
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-planner.ts b/barretenberg/ts/dev/msm-webgpu/bench-planner.ts
new file mode 100644
index 000000000000..210da3c73b83
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-planner.ts
@@ -0,0 +1,423 @@
+///
+// Standalone microbench for the GPU bin-packing planner kernel.
+// Isolates the planner from the rest of the MSM pipeline so we can
+// pin down the minimum time required to build a per-level plan
+// (chunk_plan + scatter_plan + carry_plan + new_counts/offsets) on
+// the GPU.
+//
+// Inputs (synthetic, host-built upfront):
+// counts[B] per-bucket active count drawn from Poisson(lambda).
+// offsets[B+1] prefix sum of counts.
+//
+// Each planner dispatch is one workgroup of TPB threads. Each thread
+// handles PER_THREAD buckets. B = TPB * PER_THREAD.
+//
+// Timing methodology:
+// - Compile pipeline.
+// - Warmup (1 dispatch).
+// - For each rep:
+// Encode DISP back-to-back dispatches in ONE command encoder.
+// performance.now() right before submit and after await.
+// - Per-planner time = sample / DISP.
+// - Report min, median, max across reps.
+//
+// Validation (?validate=1):
+// - Run 1 dispatch.
+// - Read back chunk_plan, scatter_plan, carry_plan, new_counts,
+// new_offsets, totals.
+// - Cross-check against a host-side bin-pack reference.
+
+import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
+import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
+import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
+import { makeResultsClient } from './results_post.js';
+
+// Module-level bench tunables. These are defaults; parseParams() overwrites
+// them from the URL query string before the bench runs.
+// Number of buckets B in the synthetic input; runOne requires B === TPB * PER_THREAD.
+let BUCKETS = 4096;
+let LAMBDA = 32; // mean per-bucket count
+// Number of pair slots per chunk in the plan (see buildHostReference).
+let S = 16;
+// Threads in the single planner workgroup.
+let TPB = 256;
+let PER_THREAD = 16; // BUCKETS / TPB
+// Planner dispatches encoded back-to-back into one command encoder per timed sample.
+let DISP = 128;
+// Number of timed samples.
+let REPS = 5;
+// When true (?validate=1), GPU output is read back and compared with the host reference.
+let VALIDATE = false;
+
+/**
+ * Creates a deterministic 32-bit linear congruential generator
+ * (multiplier 1664525, increment 1013904223, modulus 2^32).
+ *
+ * @param seed initial state, coerced to an unsigned 32-bit integer; a seed
+ *   that coerces to 0 is replaced by 1.
+ * @returns a function yielding the next state as an unsigned 32-bit integer.
+ */
+function makeRng(seed: number): () => number {
+  const seed32 = seed >>> 0;
+  let lcg = seed32 === 0 ? 1 : seed32;
+  return function next(): number {
+    const advanced = Math.imul(lcg, 1664525) + 1013904223;
+    lcg = advanced >>> 0;
+    return lcg;
+  };
+}
+
+/**
+ * Draws an approximate Poisson(lambda) sample with Knuth's multiplicative
+ * method: uniform draws are multiplied together until the running product
+ * falls to exp(-lambda) or below.
+ *
+ * The loop is capped at 201 draws, so the result never exceeds 200.
+ *
+ * @param lambda mean of the distribution.
+ * @param rng source of unsigned 32-bit integers; each value is scaled to [0, 1).
+ * @returns the number of draws taken, minus one.
+ */
+function poisson(lambda: number, rng: () => number): number {
+  const threshold = Math.exp(-lambda);
+  let draws = 0;
+  let product = 1.0;
+  do {
+    draws += 1;
+    product *= (rng() >>> 0) / 0x100000000;
+    // Written as !(a <= b) rather than (a > b) so a NaN threshold keeps looping until the cap.
+  } while (!(product <= threshold) && draws <= 200);
+  return draws - 1;
+}
+
+/**
+ * Builds the synthetic planner input: one poisson(lambda) count per bucket and
+ * the prefix sum of those counts.
+ *
+ * @param B number of buckets.
+ * @param lambda mean passed to poisson().
+ * @param seed seed for the deterministic generator from makeRng().
+ * @returns `counts` (length B) and `offsets` (length B + 1, offsets[0] = 0).
+ */
+function buildSyntheticCounts(B: number, lambda: number, seed: number): { counts: Uint32Array; offsets: Uint32Array } {
+  const nextU32 = makeRng(seed);
+  const counts = new Uint32Array(B);
+  for (let i = 0; i < B; i += 1) {
+    counts[i] = poisson(lambda, nextU32);
+  }
+  const offsets = new Uint32Array(B + 1);
+  counts.forEach((count, i) => {
+    offsets[i + 1] = offsets[i] + count;
+  });
+  return { counts, offsets };
+}
+
+/**
+ * Host-side bin-pack reference against which runOne compares the GPU planner's
+ * output.
+ *
+ * A bucket with n active entries contributes floor(n / 2) pairs and, when n is
+ * odd, one carry. Pairs are assigned consecutive slots in bucket order; slot k
+ * lands in chunk floor(k / S) at position k % S.
+ *
+ * @param counts per-bucket active counts.
+ * @param offsets prefix sum of `counts`; offsets[b] is bucket b's first entry.
+ * @param S number of pair slots per chunk.
+ * @returns
+ *   - chunkPlan: two input indices per slot (the members of each pair),
+ *   - scatterPlan: one output index per slot,
+ *   - carryPlan: (input index, output index) per carry,
+ *   - newCounts / newOffsets: per-bucket output counts and their prefix sum,
+ *   - totalPairs, totalCarries, totalNew, numChunks (at least 1).
+ */
+function buildHostReference(counts: Uint32Array, offsets: Uint32Array, S: number) {
+  const bucketCount = counts.length;
+  const newCounts = new Uint32Array(bucketCount);
+  const newOffsets = new Uint32Array(bucketCount + 1);
+
+  // Pass 1: per-bucket output sizes, their prefix sum, and the global totals.
+  let totalPairs = 0;
+  let totalCarries = 0;
+  let totalNew = 0;
+  counts.forEach((active, b) => {
+    const pairs = Math.floor(active / 2);
+    const carry = active & 1;
+    const produced = pairs + carry;
+    newCounts[b] = produced;
+    newOffsets[b + 1] = newOffsets[b] + produced;
+    totalPairs += pairs;
+    totalCarries += carry;
+    totalNew += produced;
+  });
+
+  // Plans are allocated with at least one chunk / one carry entry so they are never empty.
+  const numChunks = Math.max(1, Math.ceil(totalPairs / S));
+  const chunkPlan = new Uint32Array(2 * numChunks * S);
+  const scatterPlan = new Uint32Array(numChunks * S);
+  const carryPlan = new Uint32Array(2 * Math.max(1, totalCarries));
+
+  // Pass 2: hand out pair slots and carry entries in bucket order.
+  let nextSlot = 0;
+  let nextCarry = 0;
+  counts.forEach((active, b) => {
+    const pairs = Math.floor(active / 2);
+    const inBase = offsets[b];
+    const outBase = newOffsets[b];
+    for (let j = 0; j < pairs; j += 1, nextSlot += 1) {
+      const chunkId = Math.floor(nextSlot / S);
+      const slotInChunk = nextSlot % S;
+      const planIdx = chunkId * S + slotInChunk;
+      chunkPlan[2 * planIdx] = inBase + 2 * j;
+      chunkPlan[2 * planIdx + 1] = inBase + 2 * j + 1;
+      scatterPlan[planIdx] = outBase + j;
+    }
+    if ((active & 1) !== 0) {
+      // The odd entry out is the bucket's last input; it goes after the bucket's pair outputs.
+      carryPlan[2 * nextCarry] = inBase + active - 1;
+      carryPlan[2 * nextCarry + 1] = outBase + pairs;
+      nextCarry += 1;
+    }
+  });
+
+  return { chunkPlan, scatterPlan, carryPlan, newCounts, newOffsets, totalPairs, totalCarries, totalNew, numChunks };
+}
+
+/**
+ * Median of a list of numbers.
+ *
+ * For an odd-length list this is the middle element of the sorted values; for
+ * an even-length list it is the mean of the two middle elements (previously
+ * the upper-middle element was returned, which biased even-sized samples
+ * upward). The input array is not modified.
+ *
+ * @param xs samples in any order.
+ * @returns the median, or NaN when `xs` is empty.
+ */
+function median(xs: number[]): number {
+  if (xs.length === 0) return NaN;
+  const sorted = xs.slice().sort((a, b) => a - b);
+  const mid = Math.floor(sorted.length / 2);
+  if (sorted.length % 2 === 1) return sorted[mid];
+  return (sorted[mid - 1] + sorted[mid]) / 2;
+}
+
+// Result record produced by one runOne() call: the configuration that was
+// benchmarked, the host-reference totals, and the timing summary.
+interface BenchResult {
+ // Configuration echoed from the module-level tunables.
+ buckets: number;
+ lambda: number;
+ s: number;
+ tpb: number;
+ per_thread: number;
+ disp: number;
+ reps: number;
+ // Totals taken from the host-side reference plan.
+ total_pairs: number;
+ total_carries: number;
+ num_chunks: number;
+ // Per-dispatch time in microseconds: each wall-clock sample divided by `disp`.
+ per_dispatch_us: { min: number; median: number; max: number };
+ // Raw wall-clock time of each timed submit, in milliseconds (one entry per rep).
+ wall_samples_ms: number[];
+ // True only when validation was requested and the GPU output matched the host reference.
+ validated: boolean;
+}
+
+/**
+ * Mutable run state published on `window.__bench` and included in the final
+ * results post.
+ */
+interface BenchState {
+  /** Lifecycle of the run: 'boot' until parameters are parsed, then 'running', then 'done' or 'error'. */
+  state: 'boot' | 'running' | 'done' | 'error';
+  /** Effective parameters returned by parseParams(), or null before parsing. */
+  // `Record` needs both type arguments to type-check; a string-keyed bag accepts the parseParams() result.
+  params: Record<string, unknown> | null;
+  /** One entry per completed bench run. */
+  results: BenchResult[];
+  /** Message of the failure that ended the run, or null. */
+  error: string | null;
+  /** Every line passed to log(), prefixed with its level. */
+  log: string[];
+}
+
+// Shared run state, exposed as window.__bench.
+const benchState: BenchState = { state: 'boot', params: null, results: [], error: null, log: [] };
+(window as unknown as { __bench: BenchState }).__bench = benchState;
+// Results client for this page; its run id is exposed as window.__runId.
+const resultsClient = makeResultsClient({ page: 'bench-planner' });
+(window as unknown as { __runId: string }).__runId = resultsClient.runId;
+
+/**
+ * Posts the final bench state (state, params, results, error, log) together
+ * with the browser's user agent and hardware concurrency through the results
+ * client.
+ *
+ * The return type is spelled `Promise<void>`; a bare `Promise` does not
+ * type-check.
+ */
+async function postFinal(): Promise<void> {
+  await resultsClient.postResults({
+    state: benchState.state,
+    params: benchState.params,
+    results: benchState.results,
+    error: benchState.error,
+    log: benchState.log,
+    userAgent: navigator.userAgent,
+    hardwareConcurrency: navigator.hardwareConcurrency,
+  });
+}
+
+// Container element that receives one child <div> per log line.
+const $log = document.getElementById('log') as HTMLDivElement;
+
+/**
+ * Emits a log line to three sinks: the on-page log container, the
+ * `benchState.log` array (prefixed with the level) and the browser console.
+ *
+ * @param level severity; 'ok', 'err' and 'warn' become the row's CSS class,
+ *   'info' leaves the class empty.
+ * @param msg text of the line.
+ */
+function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) {
+  let cssClass: string;
+  switch (level) {
+    case 'ok':
+    case 'err':
+    case 'warn':
+      cssClass = level;
+      break;
+    default:
+      cssClass = '';
+  }
+  const row = document.createElement('div');
+  row.className = cssClass;
+  row.textContent = msg;
+  $log.appendChild(row);
+  benchState.log.push(`[${level}] ${msg}`);
+  console.log(`[bench-planner] ${msg}`);
+}
+
+/**
+ * Compiles a WGSL compute shader and builds a compute pipeline for it.
+ *
+ * Compilation messages are inspected before the pipeline is created: errors
+ * are sent to console.error and the page log, all other messages to
+ * console.warn.
+ *
+ * The return type is spelled `Promise<GPUComputePipeline>`; a bare `Promise`
+ * does not type-check.
+ *
+ * @param device WebGPU device used to create the module and the pipeline.
+ * @param code WGSL source.
+ * @param key label used only in diagnostics.
+ * @param layout the single bind group layout of the pipeline.
+ * @returns the pipeline, using entry point 'main'.
+ * @throws Error when the compiler reports at least one error; the message
+ *   carries the first four error lines.
+ */
+async function compileOne(device: GPUDevice, code: string, key: string, layout: GPUBindGroupLayout): Promise<GPUComputePipeline> {
+  const module = device.createShaderModule({ code });
+  const info = await module.getCompilationInfo();
+  const errLines: string[] = [];
+  for (const m of info.messages) {
+    const line = `[shader ${key}] ${m.type}: ${m.message} (line ${m.lineNum}, col ${m.linePos})`;
+    if (m.type === 'error') {
+      console.error(line);
+      log('err', line);
+      errLines.push(line);
+    } else {
+      console.warn(line);
+    }
+  }
+  if (errLines.length > 0) throw new Error(`WGSL compile failed for ${key}: ${errLines.slice(0, 4).join(' | ')}`);
+  return device.createComputePipelineAsync({
+    layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }),
+    compute: { module, entryPoint: 'main' },
+  });
+}
+
+/**
+ * Copies the first `byteLength` bytes of a GPU buffer to the host and returns
+ * them as 32-bit words.
+ *
+ * The data goes through a temporary MAP_READ staging buffer. That buffer is
+ * now destroyed in a `finally`, so it is released even when the copy or the
+ * mapping rejects (previously it leaked on failure).
+ *
+ * The return type is spelled `Promise<Uint32Array>`; a bare `Promise` does not
+ * type-check.
+ *
+ * @param device WebGPU device owning `buf`.
+ * @param buf source buffer; it is used as a copy source here.
+ * @param byteLength number of bytes to read, starting at offset 0.
+ * @returns a host-side copy that stays valid after the staging buffer is gone.
+ */
+async function readbackU32(device: GPUDevice, buf: GPUBuffer, byteLength: number): Promise<Uint32Array> {
+  const staging = device.createBuffer({ size: byteLength, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST });
+  try {
+    const enc = device.createCommandEncoder();
+    enc.copyBufferToBuffer(buf, 0, staging, 0, byteLength);
+    device.queue.submit([enc.finish()]);
+    await device.queue.onSubmittedWorkDone();
+    await staging.mapAsync(GPUMapMode.READ);
+    // slice(0) copies the mapped range so the result outlives unmap()/destroy().
+    const out = new Uint32Array(staging.getMappedRange().slice(0));
+    staging.unmap();
+    return out;
+  } finally {
+    staging.destroy();
+  }
+}
+
+/**
+ * Runs the planner microbench once with the current module-level tunables.
+ *
+ * Steps: build synthetic counts and the host reference plan, allocate GPU
+ * buffers sized from that plan, compile the planner shader, run one warmup
+ * dispatch, optionally validate the warmup's output against the host
+ * reference, then time REPS submits of DISP back-to-back dispatches each.
+ *
+ * The return type is spelled `Promise<BenchResult>`; a bare `Promise` does not
+ * type-check.
+ *
+ * @param device WebGPU device to run on.
+ * @param sm shader manager that generates the planner WGSL.
+ * @returns configuration, host-reference totals and timing summary.
+ * @throws Error when BUCKETS !== TPB * PER_THREAD, or when shader compilation fails.
+ */
+async function runOne(device: GPUDevice, sm: ShaderManager): Promise<BenchResult> {
+  log('info', `=== B=${BUCKETS} λ=${LAMBDA} S=${S} TPB=${TPB} PER=${PER_THREAD} DISP=${DISP} REPS=${REPS}`);
+  if (TPB * PER_THREAD !== BUCKETS) throw new Error(`BUCKETS=${BUCKETS} must equal TPB*PER_THREAD=${TPB * PER_THREAD}`);
+
+  const { counts, offsets } = buildSyntheticCounts(BUCKETS, LAMBDA, 0x5fa11);
+  let totalActive = 0;
+  let cMin = 99999, cMax = 0;
+  for (let b = 0; b < BUCKETS; b++) {
+    totalActive += counts[b];
+    if (counts[b] > cMax) cMax = counts[b];
+    if (counts[b] < cMin) cMin = counts[b];
+  }
+  log('info', `synthetic counts: min=${cMin} max=${cMax} totalActive=${totalActive}`);
+
+  const ref = buildHostReference(counts, offsets, S);
+  log('info', `host reference: totalPairs=${ref.totalPairs} totalCarries=${ref.totalCarries} numChunks=${ref.numChunks}`);
+
+  // Allocate output buffers sized for the host-computed plan.
+  // For a real MSM these sizes would be conservative max bounds; here
+  // we use exact host values for tighter validation.
+  const mkStorage = (bytes: number, copySrc = false, copyDst = false): GPUBuffer => {
+    let usage = GPUBufferUsage.STORAGE;
+    if (copySrc) usage |= GPUBufferUsage.COPY_SRC;
+    if (copyDst) usage |= GPUBufferUsage.COPY_DST;
+    return device.createBuffer({ size: bytes, usage });
+  };
+
+  // Inputs: uploaded once, so they only need COPY_DST.
+  const countsBuf = mkStorage(counts.byteLength, false, true);
+  const offsetsBuf = mkStorage(offsets.byteLength, false, true);
+  device.queue.writeBuffer(countsBuf, 0, counts);
+  device.queue.writeBuffer(offsetsBuf, 0, offsets);
+
+  const chunkPlanBytes = ref.chunkPlan.byteLength;
+  const scatterPlanBytes = ref.scatterPlan.byteLength;
+  const carryPlanBytes = ref.carryPlan.byteLength;
+  const newCountsBytes = ref.newCounts.byteLength;
+  const newOffsetsBytes = ref.newOffsets.byteLength;
+  const totalsBytes = 16;
+
+  // Outputs: COPY_SRC so validation can read them back.
+  const chunkPlanBuf = mkStorage(chunkPlanBytes, true);
+  const scatterPlanBuf = mkStorage(scatterPlanBytes, true);
+  const carryPlanBuf = mkStorage(carryPlanBytes, true);
+  const newCountsBuf = mkStorage(newCountsBytes, true);
+  const newOffsetsBuf = mkStorage(newOffsetsBytes, true);
+  const totalsBuf = mkStorage(totalsBytes, true);
+
+  const paramsBuf = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST });
+  device.queue.writeBuffer(paramsBuf, 0, new Uint32Array([BUCKETS, S, 0, 0]));
+
+  // Bindings 0-1 are the read-only inputs, 2-7 the writable outputs, 8 the uniform params.
+  const layout = device.createBindGroupLayout({
+    entries: [
+      { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+      { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+      { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+      { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+      { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+      { binding: 5, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+      { binding: 6, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+      { binding: 7, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+      { binding: 8, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } },
+    ],
+  });
+  const pipeline = await compileOne(device, sm.gen_ba_planner_v2_bench_shader(TPB, PER_THREAD, S, 64), `planner-v2-T${TPB}-P${PER_THREAD}-S${S}`, layout);
+  const bind = device.createBindGroup({
+    layout,
+    entries: [
+      { binding: 0, resource: { buffer: countsBuf } },
+      { binding: 1, resource: { buffer: offsetsBuf } },
+      { binding: 2, resource: { buffer: chunkPlanBuf } },
+      { binding: 3, resource: { buffer: scatterPlanBuf } },
+      { binding: 4, resource: { buffer: carryPlanBuf } },
+      { binding: 5, resource: { buffer: newCountsBuf } },
+      { binding: 6, resource: { buffer: newOffsetsBuf } },
+      { binding: 7, resource: { buffer: totalsBuf } },
+      { binding: 8, resource: { buffer: paramsBuf } },
+    ],
+  });
+
+  // Warmup.
+  {
+    const enc = device.createCommandEncoder();
+    const pass = enc.beginComputePass();
+    pass.setPipeline(pipeline);
+    pass.setBindGroup(0, bind);
+    pass.dispatchWorkgroups(1, 1, 1);
+    pass.end();
+    device.queue.submit([enc.finish()]);
+    await device.queue.onSubmittedWorkDone();
+  }
+  log('info', 'warmup done');
+
+  // Validation: read back what the warmup dispatch wrote and compare it
+  // against the host reference.
+  let validated = false;
+  if (VALIDATE) {
+    const gpuTotals = await readbackU32(device, totalsBuf, totalsBytes);
+    const gpuNewCounts = await readbackU32(device, newCountsBuf, newCountsBytes);
+    const gpuNewOffsets = await readbackU32(device, newOffsetsBuf, newOffsetsBytes);
+    const gpuChunkPlan = await readbackU32(device, chunkPlanBuf, chunkPlanBytes);
+    const gpuScatterPlan = await readbackU32(device, scatterPlanBuf, scatterPlanBytes);
+    const gpuCarryPlan = await readbackU32(device, carryPlanBuf, carryPlanBytes);
+
+    const mismatches: string[] = [];
+    if (gpuTotals[0] !== ref.totalPairs) mismatches.push(`totals[0]: gpu=${gpuTotals[0]} ref=${ref.totalPairs}`);
+    if (gpuTotals[1] !== ref.totalCarries) mismatches.push(`totals[1]: gpu=${gpuTotals[1]} ref=${ref.totalCarries}`);
+    if (gpuTotals[2] !== ref.totalNew) mismatches.push(`totals[2]: gpu=${gpuTotals[2]} ref=${ref.totalNew}`);
+    // Per-bucket scan stops once 8 mismatches have been collected.
+    for (let b = 0; b < BUCKETS && mismatches.length < 8; b++) {
+      if (gpuNewCounts[b] !== ref.newCounts[b]) mismatches.push(`newCounts[${b}]: gpu=${gpuNewCounts[b]} ref=${ref.newCounts[b]}`);
+      if (gpuNewOffsets[b] !== ref.newOffsets[b]) mismatches.push(`newOffsets[${b}]: gpu=${gpuNewOffsets[b]} ref=${ref.newOffsets[b]}`);
+    }
+    // chunk_plan/scatter_plan: compare element-wise; only the first 3 failures of each are listed.
+    let cpFails = 0;
+    for (let i = 0; i < ref.chunkPlan.length; i++) {
+      if (gpuChunkPlan[i] !== ref.chunkPlan[i]) { cpFails++; if (cpFails <= 3) mismatches.push(`chunkPlan[${i}]: gpu=${gpuChunkPlan[i]} ref=${ref.chunkPlan[i]}`); }
+    }
+    let spFails = 0;
+    for (let i = 0; i < ref.scatterPlan.length; i++) {
+      if (gpuScatterPlan[i] !== ref.scatterPlan[i]) { spFails++; if (spFails <= 3) mismatches.push(`scatterPlan[${i}]: gpu=${gpuScatterPlan[i]} ref=${ref.scatterPlan[i]}`); }
+    }
+    // carry_plan: only the 2 * totalCarries words that hold real entries are compared.
+    let cyFails = 0;
+    for (let i = 0; i < 2 * ref.totalCarries; i++) {
+      if (gpuCarryPlan[i] !== ref.carryPlan[i]) { cyFails++; if (cyFails <= 3) mismatches.push(`carryPlan[${i}]: gpu=${gpuCarryPlan[i]} ref=${ref.carryPlan[i]}`); }
+    }
+    if (mismatches.length === 0 && cpFails === 0 && spFails === 0 && cyFails === 0) {
+      validated = true;
+      log('ok', 'validation: PASS — GPU planner output byte-equivalent to host reference');
+    } else {
+      log('err', `validation: FAIL — ${cpFails} chunkPlan, ${spFails} scatterPlan, ${cyFails} carryPlan mismatches; first few:`);
+      for (const m of mismatches.slice(0, 10)) log('err', `  ${m}`);
+    }
+  }
+
+  // Timed runs: DISP back-to-back planner dispatches in one command encoder.
+  // Encoding happens before t0, so each sample covers submit + completion only.
+  const samples: number[] = [];
+  for (let r = 0; r < REPS; r++) {
+    const enc = device.createCommandEncoder();
+    for (let d = 0; d < DISP; d++) {
+      const pass = enc.beginComputePass();
+      pass.setPipeline(pipeline);
+      pass.setBindGroup(0, bind);
+      pass.dispatchWorkgroups(1, 1, 1);
+      pass.end();
+    }
+    const t0 = performance.now();
+    device.queue.submit([enc.finish()]);
+    await device.queue.onSubmittedWorkDone();
+    samples.push(performance.now() - t0);
+  }
+  const med = median(samples);
+  const mn = Math.min(...samples);
+  const mx = Math.max(...samples);
+  // Wall-clock ms per sample -> microseconds per single dispatch.
+  const perDispatchMin = (mn / DISP) * 1000;
+  const perDispatchMed = (med / DISP) * 1000;
+  const perDispatchMax = (mx / DISP) * 1000;
+
+  log(
+    'ok',
+    `per-planner: min=${perDispatchMin.toFixed(2)}μs median=${perDispatchMed.toFixed(2)}μs max=${perDispatchMax.toFixed(2)}μs` +
+      ` (total wall: min=${mn.toFixed(2)}ms median=${med.toFixed(2)}ms max=${mx.toFixed(2)}ms over DISP=${DISP})`,
+  );
+
+  countsBuf.destroy(); offsetsBuf.destroy(); chunkPlanBuf.destroy(); scatterPlanBuf.destroy();
+  carryPlanBuf.destroy(); newCountsBuf.destroy(); newOffsetsBuf.destroy(); totalsBuf.destroy();
+  paramsBuf.destroy();
+
+  return {
+    buckets: BUCKETS, lambda: LAMBDA, s: S, tpb: TPB, per_thread: PER_THREAD, disp: DISP, reps: REPS,
+    total_pairs: ref.totalPairs, total_carries: ref.totalCarries, num_chunks: ref.numChunks,
+    per_dispatch_us: { min: perDispatchMin, median: perDispatchMed, max: perDispatchMax },
+    wall_samples_ms: samples,
+    validated,
+  };
+}
+
+/**
+ * Reads the planner-bench configuration from the page URL's query string and
+ * stores it in the module-level tunables.
+ *
+ * Recognised keys: buckets, lambda, s, tpb, per, disp, reps (base-10 integers)
+ * and validate ('1' enables validation). An absent or empty key keeps the
+ * current default.
+ *
+ * Values are validated before anything is stored. Previously the raw parseInt
+ * result was stored directly, so a non-numeric value became NaN and a zero or
+ * negative value reached buffer sizing and the per-dispatch division by DISP.
+ * All keys are parsed first and committed together, so a rejected query
+ * string leaves the tunables untouched.
+ *
+ * @returns the effective configuration.
+ * @throws Error when a supplied value is not an integer, or is below 1
+ *   (below 0 for lambda).
+ */
+function parseParams() {
+  const qp = new URLSearchParams(window.location.search);
+
+  // Returns `fallback` when ?key is absent or empty, otherwise the parsed
+  // integer, rejecting NaN and anything below `min`.
+  const intParam = (key: string, fallback: number, min: number): number => {
+    const raw = qp.get(key);
+    if (!raw) return fallback;
+    const parsed = parseInt(raw, 10);
+    if (!Number.isFinite(parsed) || parsed < min) {
+      throw new Error(`?${key} must be an integer >= ${min} (got '${raw}')`);
+    }
+    return parsed;
+  };
+
+  const buckets = intParam('buckets', BUCKETS, 1);
+  const lambda = intParam('lambda', LAMBDA, 0); // a mean of 0 is valid: every bucket is empty
+  const s = intParam('s', S, 1);
+  const tpb = intParam('tpb', TPB, 1);
+  const per = intParam('per', PER_THREAD, 1);
+  const disp = intParam('disp', DISP, 1);
+  const reps = intParam('reps', REPS, 1);
+
+  BUCKETS = buckets;
+  LAMBDA = lambda;
+  S = s;
+  TPB = tpb;
+  PER_THREAD = per;
+  DISP = disp;
+  REPS = reps;
+  if (qp.get('validate') === '1') VALIDATE = true;
+
+  return { buckets: BUCKETS, lambda: LAMBDA, s: S, tpb: TPB, per: PER_THREAD, disp: DISP, reps: REPS, validate: VALIDATE };
+}
+
+/**
+ * Page entry point: parses the query parameters, acquires a WebGPU device,
+ * runs the planner bench once and records the result in `benchState`.
+ *
+ * Every failure inside the body is caught here: it is logged and stored in
+ * `benchState.error`, and `benchState.state` becomes 'error'.
+ */
+async function main() {
+  try {
+    if (!('gpu' in navigator)) throw new Error('navigator.gpu missing');
+
+    const params = parseParams();
+    benchState.params = params;
+    log('info', `params: ${JSON.stringify(params)}`);
+
+    benchState.state = 'running';
+    const device = await get_device();
+    log('info', 'WebGPU device acquired');
+
+    const shaders = new ShaderManager(4, BUCKETS, BN254_CURVE_CONFIG, false);
+    const outcome = await runOne(device, shaders);
+    benchState.results.push(outcome);
+    resultsClient.postProgress({
+      kind: 'planner_done',
+      per_dispatch_us: outcome.per_dispatch_us,
+      validated: outcome.validated,
+    });
+
+    benchState.state = 'done';
+    log('ok', 'done');
+  } catch (err) {
+    let detail: string;
+    if (err instanceof Error) {
+      detail = `${err.message}\n${err.stack}`;
+    } else {
+      detail = String(err);
+    }
+    log('err', `FATAL: ${detail}`);
+    benchState.state = 'error';
+    benchState.error = detail;
+  }
+}
+
+// Kick off the bench. main() handles its own failures, so the catch below is a
+// last-resort net; the final results are posted whichever way the run ended,
+// and a failure to post is deliberately ignored.
+main()
+  .catch(reason => {
+    let text: string;
+    if (reason instanceof Error) {
+      text = reason.message;
+    } else {
+      text = String(reason);
+    }
+    log('err', `unhandled: ${text}`);
+    benchState.state = 'error';
+    benchState.error = text;
+  })
+  .finally(() => {
+    postFinal().catch(() => {});
+  });
diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs
index a9e3eacd368d..4bb7c19a9751 100644
--- a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs
+++ b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs
@@ -127,6 +127,17 @@ if (!TARGETS[argv.target]) {
const pageMap = {
"bench-batch-affine": "/dev/msm-webgpu/bench-batch-affine.html",
+ "bench-fused-wg-scan": "/dev/msm-webgpu/bench-fused-wg-scan.html",
+ "bench-ba-rev-packed-carry": "/dev/msm-webgpu/bench-ba-rev-packed-carry.html",
+ "bench-msm-chain": "/dev/msm-webgpu/bench-msm-chain.html",
+ "bench-ba-pair-disjoint": "/dev/msm-webgpu/bench-ba-pair-disjoint.html",
+ "bench-msm-tree": "/dev/msm-webgpu/bench-msm-tree.html",
+ "bench-msm-tree-v2": "/dev/msm-webgpu/bench-msm-tree-v2.html",
+ "bench-msm-tree-v3": "/dev/msm-webgpu/bench-msm-tree-v3.html",
+ "bench-planner": "/dev/msm-webgpu/bench-planner.html",
+ "bench-csr-to-v2": "/dev/msm-webgpu/bench-csr-to-v2.html",
+ "bench-msm-oracle": "/dev/msm-webgpu/bench-msm-oracle.html",
+ "bench-msm-oracle-prod": "/dev/msm-webgpu/bench-msm-oracle-prod.html",
"bench-smvp-tree": "/dev/msm-webgpu/bench-smvp-tree.html",
sanity: "/dev/msm-webgpu/index.html",
};
diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts
index 6c97ef584582..24bd0a2f4ba1 100644
--- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts
+++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts
@@ -5,11 +5,29 @@ import {
batch_affine_apply as batch_affine_apply_shader,
batch_affine_apply_scatter as batch_affine_apply_scatter_shader,
batch_affine_dispatch_args as batch_affine_dispatch_args_shader,
+ ba_carry_copy_bench as ba_carry_copy_bench_shader,
+ ba_fused_super_bench as ba_fused_super_bench_shader,
+ ba_marshal_chain_bench as ba_marshal_chain_bench_shader,
+ ba_marshal_pairs_bench as ba_marshal_pairs_bench_shader,
+ ba_marshal_tree_l0_bench as ba_marshal_tree_l0_bench_shader,
+ ba_pair_disjoint_bench as ba_pair_disjoint_bench_shader,
+ ba_pair_disjoint_tree_bench as ba_pair_disjoint_tree_bench_shader,
+ ba_planner_bench as ba_planner_bench_shader,
+ ba_planner_v2_bench as ba_planner_v2_bench_shader,
+ ba_planner_v2_prod as ba_planner_v2_prod_shader,
+ ba_marshal_pairs_prod as ba_marshal_pairs_prod_shader,
+ ba_pair_disjoint_tree_prod as ba_pair_disjoint_tree_prod_shader,
+ ba_scatter_pairs_prod as ba_scatter_pairs_prod_shader,
+ ba_carry_copy_prod as ba_carry_copy_prod_shader,
+ ba_scatter_pairs_bench as ba_scatter_pairs_bench_shader,
+ ba_tail_reduce_bench as ba_tail_reduce_bench_shader,
+ ba_rev_packed_carry_bench as ba_rev_packed_carry_bench_shader,
bench_batch_affine as bench_batch_affine_shader,
batch_affine_finalize as batch_affine_finalize_shader,
batch_affine_finalize_apply as batch_affine_finalize_apply_shader,
batch_affine_finalize_collect as batch_affine_finalize_collect_shader,
batch_affine_fused_revcarry as batch_affine_fused_revcarry_shader,
+ batch_affine_fused_wg_scan as batch_affine_fused_wg_scan_shader,
batch_affine_init as batch_affine_init_shader,
batch_affine_schedule as batch_affine_schedule_shader,
batch_inverse as batch_inverse_shader,
@@ -27,6 +45,9 @@ import {
bpr_bn254 as bpr_bn254_shader,
convert_point_coords_and_decompose_scalars,
convert_points_only as convert_points_only_shader,
+ csr_to_v2_active_sums as csr_to_v2_active_sums_shader,
+ csr_to_v2_meta as csr_to_v2_meta_shader,
+ v2_to_running as v2_to_running_shader,
decompose_scalars_signed_only as decompose_scalars_signed_only_shader,
decompress_g1_bn254 as decompress_g1_bn254_shader,
divsteps_bench as divsteps_bench_shader,
@@ -42,6 +63,7 @@ import {
mont_pro_product_f32_22_sos3uv3 as montgomery_product_f32_22_sos3uv3_funcs,
mont_pro_product_karat_yuval as montgomery_product_karat_yuval_funcs,
mulhilo_22 as mulhilo_22_funcs,
+ packed_field as packed_field_funcs,
smvp_bn254 as smvp_bn254_shader,
smvp_tree_entry_bucket_id as smvp_tree_entry_bucket_id_shader,
smvp_tree_phase1 as smvp_tree_phase1_shader,
@@ -762,6 +784,491 @@ ${packLines.join('\n')}
);
}
+ /**
+ * Marshal kernel for the bench-msm-chain pipeline: transposes a
+ * CSR-sorted point list into the strided SoA layout the
+ * ba_rev_packed_carry chain kernel consumes. Pure memory shuffle.
+ */
+ public gen_ba_marshal_chain_shader(workgroup_size: number, s: number): string {
+ if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) {
+ throw new Error(`gen_ba_marshal_chain_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`);
+ }
+ return mustache.render(
+ ba_marshal_chain_bench_shader,
+ { workgroup_size, s, num_words: this.num_words, recompile: this.recompile },
+ { structs },
+ );
+ }
+
+ /**
+ * Fused super-kernel for the v3 pipeline: combines marshal + disjoint
+ * + scatter into one kernel. Per chunk-thread: reads chunk_plan +
+ * scatter_plan, gathers from active_sums_old, computes batched-
+ * inverse pair sums in registers, writes directly to active_sums_new.
+ * Carry is a separate kernel.
+ */
+ public gen_ba_fused_super_bench_shader(workgroup_size: number, s: number): string {
+ if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) {
+ throw new Error(`gen_ba_fused_super_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`);
+ }
+ const dec = this.decoupledPackUnpackWgsl();
+ return mustache.render(
+ ba_fused_super_bench_shader,
+ {
+ workgroup_size, s,
+ word_size: this.word_size, num_words: this.num_words, n0: this.n0,
+ p_limbs: this.p_limbs, r_limbs: this.r_limbs, r_cubed_limbs: this.r_cubed_limbs,
+ p_minus_2_limbs: this.p_minus_2_limbs, mask: this.mask,
+ two_pow_word_size: this.two_pow_word_size, p_inv_mod_2w: this.p_inv_mod_2w,
+ p_inv_by_a_lo: this.p_inv_by_a_lo,
+ dec_unpack: dec.unpack, dec_pack: dec.pack, recompile: this.recompile,
+ },
+ {
+ structs, bigint_funcs,
+ montgomery_product_funcs: this.mont_product_src,
+ field_funcs, fr_pow_funcs, bigint_by_funcs, by_inverse_a_funcs,
+ },
+ );
+ }
+
+  /**
+   * v2 GPU planner: single-kernel scan + scatter. One workgroup of TPB
+   * threads handles all B buckets via per-thread local scan + workgroup-
+   * wide Hillis-Steele scan + per-thread scatter. No atomics, no host
+   * sync, single dispatch. Needs B <= TPB * PER_THREAD (256 * 32 = 8192).
+   * @throws Error listing every argument value unless all are positive integers.
+   */
+  public gen_ba_planner_v2_bench_shader(workgroup_size: number, per_thread: number, s: number, pair_cap: number = 64): string {
+    if (workgroup_size <= 0 || per_thread <= 0 || s <= 0 || pair_cap <= 0 ||
+        !Number.isInteger(workgroup_size) || !Number.isInteger(per_thread) || !Number.isInteger(s) || !Number.isInteger(pair_cap)) {
+      throw new Error(`gen_ba_planner_v2_bench_shader: workgroup_size (${workgroup_size}), per_thread (${per_thread}), s (${s}), and pair_cap (${pair_cap}) must be positive integers`);
+    }
+    return mustache.render(
+      ba_planner_v2_bench_shader,
+      { workgroup_size, per_thread, pair_cap, s, num_words: this.num_words, recompile: this.recompile },
+      { structs },
+    );
+  }
+
+ /**
+ * GPU-side bin-packing planner for the v3 pipeline. One thread per
+ * bucket; atomicAdd reserves global per-pair / per-carry / per-new-
+ * slot offsets; the thread writes chunk_plan + scatter_plan +
+ * carry_plan + new_counts + new_offsets for its bucket. Pair-count
+ * loop bounded by compile-time `pair_cap` (defaults to 64 — covers
+ * Poisson(λ=32) max-count without issue).
+ */
+ public gen_ba_planner_bench_shader(workgroup_size: number, s: number, pair_cap: number = 64): string {
+ if (workgroup_size <= 0 || s <= 0 || pair_cap <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s) || !Number.isInteger(pair_cap)) {
+ throw new Error(`gen_ba_planner_bench_shader: workgroup_size (${workgroup_size}), s (${s}), and pair_cap (${pair_cap}) must be positive integers`);
+ }
+ return mustache.render(
+ ba_planner_bench_shader,
+ { workgroup_size, s, pair_cap, num_words: this.num_words, recompile: this.recompile },
+ { structs },
+ );
+ }
+
+ /**
+ * Bin-packed pair-tree: marshal kernel that gathers operands from a
+ * generic active_sums buffer per chunk_plan. Works at any level
+ * (L0 active_sums = bucket-sorted point pool, L1+ = previous
+ * level's pair-sums + carries).
+ */
+ public gen_ba_marshal_pairs_bench_shader(workgroup_size: number, s: number): string {
+ if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) {
+ throw new Error(`gen_ba_marshal_pairs_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`);
+ }
+ return mustache.render(
+ ba_marshal_pairs_bench_shader,
+ { workgroup_size, s, num_words: this.num_words, recompile: this.recompile },
+ { structs },
+ );
+ }
+
+ /**
+ * Bin-packed pair-tree: scatter kernel that places the disjoint
+ * kernel's strided outputs at per-bucket destinations in
+ * active_sums_new per scatter_plan.
+ */
+ public gen_ba_scatter_pairs_bench_shader(workgroup_size: number, s: number): string {
+ if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) {
+ throw new Error(`gen_ba_scatter_pairs_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`);
+ }
+ return mustache.render(
+ ba_scatter_pairs_bench_shader,
+ { workgroup_size, s, num_words: this.num_words, recompile: this.recompile },
+ { structs },
+ );
+ }
+
+ /**
+ * Layout converter (active_sums materialization): copies packed
+ * 8×u32 base coords from cached_bases into bucket-major active_sums
+ * indexed by val_idx. One thread per (subtask, slot). Pure raw vec4
+ * copy — no field-element math.
+ */
+ public gen_csr_to_v2_active_sums_shader(workgroup_size: number): string {
+ if (workgroup_size <= 0 || !Number.isInteger(workgroup_size)) {
+ throw new Error(`gen_csr_to_v2_active_sums_shader: workgroup_size (${workgroup_size}) must be a positive integer`);
+ }
+ return mustache.render(
+ csr_to_v2_active_sums_shader,
+ { workgroup_size, recompile: this.recompile },
+ {},
+ );
+ }
+
+  /**
+   * Production v2 GPU planner. Same algorithm as gen_ba_planner_v2_bench_shader,
+   * additionally emits per-level dispatch_args triples into totals[4..6]
+   * (marshal/disjoint/scatter) and totals[7..9] (carry) so the host
+   * orchestrator can drive the four downstream prod kernels via
+   * dispatchWorkgroupsIndirect with zero pad-chunk waste. wgi must match the
+   * workgroup size used to compile the four downstream prod kernels.
+   * @throws Error listing every argument value unless all are positive integers.
+   */
+  public gen_ba_planner_v2_prod_shader(workgroup_size: number, per_thread: number, s: number, wgi: number, pair_cap: number = 64): string {
+    if (workgroup_size <= 0 || per_thread <= 0 || s <= 0 || wgi <= 0 || pair_cap <= 0 ||
+        !Number.isInteger(workgroup_size) || !Number.isInteger(per_thread) || !Number.isInteger(s) || !Number.isInteger(wgi) || !Number.isInteger(pair_cap)) {
+      throw new Error(`gen_ba_planner_v2_prod_shader: workgroup_size (${workgroup_size}), per_thread (${per_thread}), s (${s}), wgi (${wgi}), and pair_cap (${pair_cap}) must be positive integers`);
+    }
+    return mustache.render(
+      ba_planner_v2_prod_shader,
+      { workgroup_size, per_thread, pair_cap, s, wgi, num_words: this.num_words, recompile: this.recompile },
+      { structs },
+    );
+  }
+
+ /**
+ * Marshal pairs — prod variant. Reads num_chunks from
+ * totals[3] (storage), dispatched indirectly off totals[4..6].
+ */
+ public gen_ba_marshal_pairs_prod_shader(workgroup_size: number, s: number): string {
+ if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) {
+ throw new Error(`gen_ba_marshal_pairs_prod_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`);
+ }
+ return mustache.render(
+ ba_marshal_pairs_prod_shader,
+ { workgroup_size, s, num_words: this.num_words, recompile: this.recompile },
+ { structs },
+ );
+ }
+
+ /**
+ * Disjoint pair-sum tree — prod variant. Reads T_curr from
+ * totals[3] (storage), always uses the final-mode strided write.
+ */
+ public gen_ba_pair_disjoint_tree_prod_shader(workgroup_size: number, s: number): string {
+ if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) {
+ throw new Error(`gen_ba_pair_disjoint_tree_prod_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`);
+ }
+ const dec = this.decoupledPackUnpackWgsl();
+ return mustache.render(
+ ba_pair_disjoint_tree_prod_shader,
+ {
+ workgroup_size,
+ s,
+ word_size: this.word_size,
+ num_words: this.num_words,
+ n0: this.n0,
+ p_limbs: this.p_limbs,
+ r_limbs: this.r_limbs,
+ r_cubed_limbs: this.r_cubed_limbs,
+ p_minus_2_limbs: this.p_minus_2_limbs,
+ mask: this.mask,
+ two_pow_word_size: this.two_pow_word_size,
+ p_inv_mod_2w: this.p_inv_mod_2w,
+ p_inv_by_a_lo: this.p_inv_by_a_lo,
+ dec_unpack: dec.unpack,
+ dec_pack: dec.pack,
+ recompile: this.recompile,
+ },
+ {
+ structs,
+ bigint_funcs,
+ montgomery_product_funcs: this.mont_product_src,
+ field_funcs,
+ fr_pow_funcs,
+ bigint_by_funcs,
+ by_inverse_a_funcs,
+ },
+ );
+ }
+
+ /**
+ * Scatter pairs — prod variant. Reads T from totals[3] (storage).
+ */
+ public gen_ba_scatter_pairs_prod_shader(workgroup_size: number, s: number): string {
+ if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) {
+ throw new Error(`gen_ba_scatter_pairs_prod_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`);
+ }
+ return mustache.render(
+ ba_scatter_pairs_prod_shader,
+ { workgroup_size, s, num_words: this.num_words, recompile: this.recompile },
+ { structs },
+ );
+ }
+
+ /**
+ * Carry copy — prod variant. Reads num_carries from totals[1] (storage).
+ */
+ public gen_ba_carry_copy_prod_shader(workgroup_size: number): string {
+ if (workgroup_size <= 0 || !Number.isInteger(workgroup_size)) {
+ throw new Error(`gen_ba_carry_copy_prod_shader: workgroup_size (${workgroup_size}) must be a positive integer`);
+ }
+ return mustache.render(
+ ba_carry_copy_prod_shader,
+ { workgroup_size, num_words: this.num_words, recompile: this.recompile },
+ { structs },
+ );
+ }
+
+ /**
+ * Boundary adapter for the v2 pair-tree -> production finalize: copies
+ * the per-bucket reduced packed point out of the v2 combined-SoA
+ * active_sums buffer into the production running_x / running_y layout
+ * and writes bucket_active. One thread per (subtask, bucket_local);
+ * the caller binds running_x / running_y / bucket_active sub-views
+ * offset by subtask_idx * num_columns.
+ */
+ public gen_v2_to_running_shader(workgroup_size: number): string {
+ if (workgroup_size <= 0 || !Number.isInteger(workgroup_size)) {
+ throw new Error(`gen_v2_to_running_shader: workgroup_size (${workgroup_size}) must be a positive integer`);
+ }
+ return mustache.render(
+ v2_to_running_shader,
+ { workgroup_size, recompile: this.recompile },
+ {},
+ );
+ }
+
+ /**
+ * Layout converter (meta derivation): writes per-bucket count and
+ * subtask-relative offset from cuZK row_ptr. One thread per
+ * (subtask, bucket).
+ */
+ public gen_csr_to_v2_meta_shader(workgroup_size: number): string {
+ if (workgroup_size <= 0 || !Number.isInteger(workgroup_size)) {
+ throw new Error(`gen_csr_to_v2_meta_shader: workgroup_size (${workgroup_size}) must be a positive integer`);
+ }
+ return mustache.render(
+ csr_to_v2_meta_shader,
+ { workgroup_size, recompile: this.recompile },
+ {},
+ );
+ }
+
+ /**
+ * Bin-packed pair-tree: carry-copy kernel. Propagates the odd-count
+ * carry element forward to the next level without modification.
+ * Pure memory shuffle.
+ */
+ public gen_ba_carry_copy_bench_shader(workgroup_size: number): string {
+ if (workgroup_size <= 0 || !Number.isInteger(workgroup_size)) {
+ throw new Error(`gen_ba_carry_copy_bench_shader: workgroup_size (${workgroup_size}) must be a positive integer`);
+ }
+ return mustache.render(
+ ba_carry_copy_bench_shader,
+ { workgroup_size, num_words: this.num_words, recompile: this.recompile },
+ { structs },
+ );
+ }
+
+ /**
+ * Tree variant of the disjoint pair-sum kernel: writes outputs in the
+ * layout the next pair-tree level expects as input, so multi-level
+ * reductions can chain without an intervening marshal pass. Per
+ * thread: 2*S inputs -> S disjoint pair sums. Final-level flag (via
+ * params.z) switches to a simple strided write for the last pass.
+ */
+ public gen_ba_pair_disjoint_tree_bench_shader(workgroup_size: number, s: number): string {
+ if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) {
+ throw new Error(`gen_ba_pair_disjoint_tree_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`);
+ }
+ const dec = this.decoupledPackUnpackWgsl();
+ return mustache.render(
+ ba_pair_disjoint_tree_bench_shader,
+ {
+ workgroup_size,
+ s,
+ word_size: this.word_size,
+ num_words: this.num_words,
+ n0: this.n0,
+ p_limbs: this.p_limbs,
+ r_limbs: this.r_limbs,
+ r_cubed_limbs: this.r_cubed_limbs,
+ p_minus_2_limbs: this.p_minus_2_limbs,
+ mask: this.mask,
+ two_pow_word_size: this.two_pow_word_size,
+ p_inv_mod_2w: this.p_inv_mod_2w,
+ p_inv_by_a_lo: this.p_inv_by_a_lo,
+ dec_unpack: dec.unpack,
+ dec_pack: dec.pack,
+ recompile: this.recompile,
+ },
+ {
+ structs,
+ bigint_funcs,
+ montgomery_product_funcs: this.mont_product_src,
+ field_funcs,
+ fr_pow_funcs,
+ bigint_by_funcs,
+ by_inverse_a_funcs,
+ },
+ );
+ }
+
+ /**
+ * Level-0 marshal kernel for the bench-msm-tree pipeline: transposes
+ * CSR-sorted point indices into the 2-plane strided SoA layout the
+ * tree-disjoint kernel reads. Pure memory shuffle.
+ */
+ public gen_ba_marshal_tree_l0_bench_shader(workgroup_size: number, s: number): string {
+ if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) {
+ throw new Error(`gen_ba_marshal_tree_l0_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`);
+ }
+ return mustache.render(
+ ba_marshal_tree_l0_bench_shader,
+ { workgroup_size, s, num_words: this.num_words, recompile: this.recompile },
+ { structs },
+ );
+ }
+
+ /**
+ * Tail kernel for the bench-msm-tree pipeline: one thread per small
+ * bucket (count < 2*S), serial per-thread chain with one fr_inv_by_a
+ * per add (no batched inversion). Bounded loop up to compile-time
+ * TAIL_CAP = 2*S - 1.
+ */
+ public gen_ba_tail_reduce_bench_shader(workgroup_size: number, s: number): string {
+ if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) {
+ throw new Error(`gen_ba_tail_reduce_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`);
+ }
+ const tail_cap = 2 * s - 1;
+ const dec = this.decoupledPackUnpackWgsl();
+ return mustache.render(
+ ba_tail_reduce_bench_shader,
+ {
+ workgroup_size,
+ tail_cap,
+ word_size: this.word_size,
+ num_words: this.num_words,
+ n0: this.n0,
+ p_limbs: this.p_limbs,
+ r_limbs: this.r_limbs,
+ r_cubed_limbs: this.r_cubed_limbs,
+ p_minus_2_limbs: this.p_minus_2_limbs,
+ mask: this.mask,
+ two_pow_word_size: this.two_pow_word_size,
+ p_inv_mod_2w: this.p_inv_mod_2w,
+ p_inv_by_a_lo: this.p_inv_by_a_lo,
+ dec_unpack: dec.unpack,
+ dec_pack: dec.pack,
+ recompile: this.recompile,
+ },
+ {
+ structs,
+ bigint_funcs,
+ montgomery_product_funcs: this.mont_product_src,
+ field_funcs,
+ fr_pow_funcs,
+ bigint_by_funcs,
+ by_inverse_a_funcs,
+ },
+ );
+ }
+
+ /**
+ * Standalone microbench for the disjoint pair-sum kernel: each
+ * thread reduces 2*S input points to S disjoint pair sums R_k =
+ * P_{2k} + P_{2k+1}, using one batched fr_inv_by_a per chunk of S.
+ * Reclaims the 50% kernel-efficiency loss inherent in
+ * ba_rev_packed_carry (which produces S overlapping pair sums of
+ * which only S/2 are usable for pair-tree reduction).
+ */
+ public gen_ba_pair_disjoint_bench_shader(workgroup_size: number, s: number): string {
+ if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) {
+ throw new Error(`gen_ba_pair_disjoint_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`);
+ }
+ const dec = this.decoupledPackUnpackWgsl();
+ return mustache.render(
+ ba_pair_disjoint_bench_shader,
+ {
+ workgroup_size,
+ s,
+ word_size: this.word_size,
+ num_words: this.num_words,
+ n0: this.n0,
+ p_limbs: this.p_limbs,
+ r_limbs: this.r_limbs,
+ r_cubed_limbs: this.r_cubed_limbs,
+ p_minus_2_limbs: this.p_minus_2_limbs,
+ mask: this.mask,
+ two_pow_word_size: this.two_pow_word_size,
+ p_inv_mod_2w: this.p_inv_mod_2w,
+ p_inv_by_a_lo: this.p_inv_by_a_lo,
+ dec_unpack: dec.unpack,
+ dec_pack: dec.pack,
+ recompile: this.recompile,
+ },
+ {
+ structs,
+ bigint_funcs,
+ montgomery_product_funcs: this.mont_product_src,
+ field_funcs,
+ fr_pow_funcs,
+ bigint_by_funcs,
+ by_inverse_a_funcs,
+ },
+ );
+ }
+
+ /**
+ * Standalone microbench for the recovered ba_rev_packed_carry kernel:
+ * SoA-packed 8x u32 storage across 4 input planes (A.x, A.y, P.x, P.y),
+ * strided per-thread access (e = t + i*T), forward prefix-product +
+ * single fr_inv_by_a + backward peel with resident-accumulator
+ * load-carry (A_{i+1} := P_i). The canonical kernel that hit ~22
+ * ns/pair on M2 / Chrome 148.
+ */
+ public gen_ba_rev_packed_carry_bench_shader(workgroup_size: number, s: number): string {
+ if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) {
+ throw new Error(`gen_ba_rev_packed_carry_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`);
+ }
+ const dec = this.decoupledPackUnpackWgsl();
+ return mustache.render(
+ ba_rev_packed_carry_bench_shader,
+ {
+ workgroup_size,
+ s,
+ word_size: this.word_size,
+ num_words: this.num_words,
+ n0: this.n0,
+ p_limbs: this.p_limbs,
+ r_limbs: this.r_limbs,
+ r_cubed_limbs: this.r_cubed_limbs,
+ p_minus_2_limbs: this.p_minus_2_limbs,
+ mask: this.mask,
+ two_pow_word_size: this.two_pow_word_size,
+ p_inv_mod_2w: this.p_inv_mod_2w,
+ p_inv_by_a_lo: this.p_inv_by_a_lo,
+ dec_unpack: dec.unpack,
+ dec_pack: dec.pack,
+ recompile: this.recompile,
+ },
+ {
+ structs,
+ bigint_funcs,
+ montgomery_product_funcs: this.mont_product_src,
+ field_funcs,
+ fr_pow_funcs,
+ bigint_by_funcs,
+ by_inverse_a_funcs,
+ },
+ );
+ }
+
public gen_batch_affine_init_shader(workgroup_size: number, packed = false): string {
const dec = this.decoupledPackUnpackWgsl();
return mustache.render(
@@ -877,6 +1384,57 @@ ${packLines.join('\n')}
);
}
+ /**
+ * Workgroup-scan fused batch-affine round kernel (v2). Mirrors the
+ * `bench_batch_affine` design (TPB threads cooperating over a
+ * BATCH_SIZE=TPB*BS pair slice with one fr_inv_by_a per workgroup)
+ * but with bucket-indirect loads via pair_target_meta. Every
+ * field-element variable is `PackedField`; no per-load unpack at
+ * the kernel level.
+ */
+ public gen_batch_affine_fused_wg_scan_shader(tpb: number, bs: number): string {
+ if (tpb <= 0 || bs <= 0 || !Number.isInteger(tpb) || !Number.isInteger(bs)) {
+ throw new Error(`gen_batch_affine_fused_wg_scan_shader: tpb (${tpb}) and bs (${bs}) must be positive integers`);
+ }
+ if ((tpb & (tpb - 1)) !== 0) {
+ throw new Error(`gen_batch_affine_fused_wg_scan_shader: tpb (${tpb}) must be a power of two (Hillis-Steele scan)`);
+ }
+ const batch_size = tpb * bs;
+ const dec = this.decoupledPackUnpackWgsl();
+ return mustache.render(
+ batch_affine_fused_wg_scan_shader,
+ {
+ tpb,
+ bs,
+ batch_size,
+ word_size: this.word_size,
+ num_words: this.num_words,
+ n0: this.n0,
+ p_limbs: this.p_limbs,
+ r_limbs: this.r_limbs,
+ r_cubed_limbs: this.r_cubed_limbs,
+ p_minus_2_limbs: this.p_minus_2_limbs,
+ mask: this.mask,
+ two_pow_word_size: this.two_pow_word_size,
+ p_inv_mod_2w: this.p_inv_mod_2w,
+ p_inv_by_a_lo: this.p_inv_by_a_lo,
+ dec_unpack: dec.unpack,
+ dec_pack: dec.pack,
+ recompile: this.recompile,
+ },
+ {
+ structs,
+ bigint_funcs,
+ montgomery_product_funcs: this.mont_product_src,
+ field_funcs,
+ fr_pow_funcs,
+ bigint_by_funcs,
+ by_inverse_a_funcs,
+ packed_field_funcs,
+ },
+ );
+ }
+
public gen_batch_affine_finalize_shader(workgroup_size: number, num_csr_cols: number): string {
return mustache.render(
batch_affine_finalize_shader,
diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts b/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts
new file mode 100644
index 000000000000..51ae0e363a34
--- /dev/null
+++ b/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts
@@ -0,0 +1,601 @@
+///
+
+/**
+ * v2 bin-packed pair-tree MSM bucket-accumulate orchestrator.
+ *
+ * Drop-in replacement for the cuZK round-loop (`smvp_batch_affine_gpu`'s
+ * schedule + batch_inverse_parallel + apply_scatter per round) for the
+ * Pippenger bucket-accumulate phase. For each window:
+ *
+ * csr_to_v2_meta row_ptr -> per-bucket count + offset
+ * csr_to_v2_active_sums val_idx + cached bases -> bucket-major
+ * active_sums (combined SoA, packed 8x u32)
+ * for level in 0..max_levels:
+ * ba_planner_v2_prod counts/offsets -> chunk_plan / scatter_plan
+ * / carry_plan + new_counts/new_offsets +
+ * totals (incl. per-level dispatch_args
+ * triples at totals[4..6] and totals[7..9])
+ * ba_marshal_pairs_prod indirect dispatch off totals[4..6]
+ * ba_pair_disjoint_tree_prod indirect dispatch off totals[4..6]
+ * ba_scatter_pairs_prod indirect dispatch off totals[4..6]
+ * ba_carry_copy_prod indirect dispatch off totals[7..9]
+ * v2_to_running final active_sums slot per non-empty
+ * bucket -> running_x / running_y /
+ * bucket_active in production layout
+ *
+ * The prod kernels read num_chunks (= ceil(total_pairs / S)) and
+ * num_carries from the planner's totals storage output; the host
+ * dispatches them via dispatchWorkgroupsIndirect. Each level's
+ * downstream dispatch is sized to EXACTLY the chunks/carries the
+ * planner produced — zero pad-chunk waste, the runtime advantage
+ * over the pad-fill alternative.
+ *
+ * All dispatches for all windows are recorded onto one command encoder
+ * and submitted once. Submit overhead is paid once per MSM, not once
+ * per window or once per level.
+ *
+ * Layout boundaries:
+ * active_sums (combined SoA, one buffer per ping-pong copy):
+ * plane 0 (x) at vec4 indices [0, PG * M)
+ * plane 1 (y) at vec4 indices [PG * M, 2 * PG * M)
+ * element layout: PG=2 vec4 at [PG*elem, PG*elem+1]
+ * M = input_size + 3 (3 reserved tail slots: pad_left, pad_right and
+ * discard. The planner emits the pad pair into the chunk-tail for
+ * filler pairs — we initialise it once at orchestrator start with
+ * distinct-x Montgomery-form values so the disjoint kernel's lean
+ * affine add is well-defined on pad chunks even though their sums
+ * get scattered to the discard slot)
+ * running_x / running_y (production, separate buffers):
+ * packed 8x u32 = 2 vec4 per (subtask, bucket_local), at
+ * [PG * bucket_global, PG * bucket_global + 1] with
+ * bucket_global = subtask_idx * num_columns + bucket_local
+ * bucket_active: u32 per bucket_global
+ * v2_to_running binds running_x / running_y / bucket_active as
+ * sub-views offset by subtask_idx * num_columns buckets so a single
+ * per-window dispatch lands the result at the right slab.
+ */
+
+import { ShaderManager } from './shader_manager.js';
+
+const PG = 2;
+const PG_VEC4_BYTES = 16;
+const ELEMENT_BYTES = PG * PG_VEC4_BYTES;
+
+export interface SmvpV2PairTreeOptions {
+ device: GPUDevice;
+ shaderManager: ShaderManager;
+ num_subtasks: number;
+ num_columns: number;
+ input_size: number;
+
+ s?: number;
+ tpb?: number;
+ per_thread?: number;
+ wgi?: number;
+ max_levels?: number;
+
+ val_idx_buf: GPUBuffer;
+ row_ptr_buf: GPUBuffer;
+ point_x_buf: GPUBuffer;
+ point_y_buf: GPUBuffer;
+
+ running_x_buf: GPUBuffer;
+ running_y_buf: GPUBuffer;
+ bucket_active_buf: GPUBuffer;
+}
+
+export interface SmvpV2PairTreeStats {
+ levels_per_window: number;
+ num_subtasks: number;
+ num_columns: number;
+ total_passes: number;
+ gpu_wall_ms: number;
+}
+
+function roStorageEntry(binding: number): GPUBindGroupLayoutEntry {
+ return { binding, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } };
+}
+function rwStorageEntry(binding: number): GPUBindGroupLayoutEntry {
+ return { binding, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } };
+}
+function uniformEntry(binding: number): GPUBindGroupLayoutEntry {
+ return { binding, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } };
+}
+
+/**
+ * Compiles one WGSL source string into a compute pipeline bound to `layout`.
+ * Every compilation message is logged with a `[smvp-v2 <key>]` prefix;
+ * if any has type 'error' this throws before a pipeline is created.
+ */
+async function compilePipeline(device: GPUDevice, layout: GPUBindGroupLayout, code: string, key: string): Promise<GPUComputePipeline> {
+  const module = device.createShaderModule({ code });
+  const info = await module.getCompilationInfo();
+  const errs: string[] = [];
+  for (const m of info.messages) {
+    const line = `[smvp-v2 ${key}] ${m.type}: ${m.message} (line ${m.lineNum}, col ${m.linePos})`;
+    if (m.type === 'error') {
+      console.error(line);
+      errs.push(line);
+    } else {
+      console.warn(line);
+    }
+  }
+  if (errs.length > 0) {
+    throw new Error(`WGSL compile failed for ${key}: ${errs.slice(0, 4).join(' | ')}`);
+  }
+  return device.createComputePipelineAsync({
+    layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }),
+    compute: { module, entryPoint: 'main' },
+  });
+}
+
+interface Pipelines {
+ csrMeta: GPUComputePipeline;
+ csrActive: GPUComputePipeline;
+ planner: GPUComputePipeline;
+ marshal: GPUComputePipeline;
+ disjoint: GPUComputePipeline;
+ scatter: GPUComputePipeline;
+ carry: GPUComputePipeline;
+ v2ToRunning: GPUComputePipeline;
+ layouts: {
+ meta: GPUBindGroupLayout;
+ active: GPUBindGroupLayout;
+ planner: GPUBindGroupLayout;
+ marshal: GPUBindGroupLayout;
+ disjoint: GPUBindGroupLayout;
+ scatter: GPUBindGroupLayout;
+ carry: GPUBindGroupLayout;
+ v2Run: GPUBindGroupLayout;
+ };
+}
+
+/**
+ * Builds the bind-group layout for each of the eight v2 pair-tree kernels
+ * and compiles all eight pipelines concurrently. The planner is sized by
+ * tpb / per_thread; every other kernel uses workgroup size wgi.
+ */
+async function compileAll(
+  device: GPUDevice, sm: ShaderManager, wgi: number, s: number, tpb: number, per_thread: number,
+): Promise<Pipelines> {
+  const layouts: Pipelines['layouts'] = {
+    meta: device.createBindGroupLayout({
+      entries: [roStorageEntry(0), rwStorageEntry(1), rwStorageEntry(2), uniformEntry(3)],
+    }),
+    active: device.createBindGroupLayout({
+      entries: [
+        roStorageEntry(0),
+        roStorageEntry(1),
+        roStorageEntry(2),
+        rwStorageEntry(3),
+        uniformEntry(4),
+      ],
+    }),
+    planner: device.createBindGroupLayout({
+      entries: [
+        roStorageEntry(0),
+        roStorageEntry(1),
+        rwStorageEntry(2),
+        rwStorageEntry(3),
+        rwStorageEntry(4),
+        rwStorageEntry(5),
+        rwStorageEntry(6),
+        rwStorageEntry(7),
+        uniformEntry(8),
+      ],
+    }),
+    marshal: device.createBindGroupLayout({
+      entries: [roStorageEntry(0), roStorageEntry(1), rwStorageEntry(2), roStorageEntry(3), uniformEntry(4)],
+    }),
+    disjoint: device.createBindGroupLayout({
+      entries: [roStorageEntry(0), roStorageEntry(1), rwStorageEntry(2), roStorageEntry(3)],
+    }),
+    scatter: device.createBindGroupLayout({
+      entries: [roStorageEntry(0), roStorageEntry(1), rwStorageEntry(2), roStorageEntry(3), uniformEntry(4)],
+    }),
+    carry: device.createBindGroupLayout({
+      entries: [roStorageEntry(0), roStorageEntry(1), rwStorageEntry(2), roStorageEntry(3), uniformEntry(4)],
+    }),
+    v2Run: device.createBindGroupLayout({
+      entries: [
+        roStorageEntry(0),
+        roStorageEntry(1),
+        roStorageEntry(2),
+        rwStorageEntry(3),
+        rwStorageEntry(4),
+        rwStorageEntry(5),
+        uniformEntry(6),
+      ],
+    }),
+  };
+
+  const [csrMeta, csrActive, planner, marshal, disjoint, scatter, carry, v2ToRunning] = await Promise.all([
+    compilePipeline(device, layouts.meta, sm.gen_csr_to_v2_meta_shader(wgi), `csr-meta-wg${wgi}`),
+    compilePipeline(device, layouts.active, sm.gen_csr_to_v2_active_sums_shader(wgi), `csr-active-wg${wgi}`),
+    compilePipeline(
+      device,
+      layouts.planner,
+      sm.gen_ba_planner_v2_prod_shader(tpb, per_thread, s, wgi, 64),
+      `planner-v2-prod-T${tpb}-P${per_thread}-S${s}-W${wgi}`,
+    ),
+    compilePipeline(device, layouts.marshal, sm.gen_ba_marshal_pairs_prod_shader(wgi, s), `marshal-prod-W${wgi}-S${s}`),
+    compilePipeline(device, layouts.disjoint, sm.gen_ba_pair_disjoint_tree_prod_shader(wgi, s), `disjoint-prod-W${wgi}-S${s}`),
+    compilePipeline(device, layouts.scatter, sm.gen_ba_scatter_pairs_prod_shader(wgi, s), `scatter-prod-W${wgi}-S${s}`),
+    compilePipeline(device, layouts.carry, sm.gen_ba_carry_copy_prod_shader(wgi), `carry-prod-W${wgi}`),
+    compilePipeline(device, layouts.v2Run, sm.gen_v2_to_running_shader(wgi), `v2-to-running-wg${wgi}`),
+  ]);
+  return { csrMeta, csrActive, planner, marshal, disjoint, scatter, carry, v2ToRunning, layouts };
+}
+
+interface Scratch {
+ activeA: GPUBuffer;
+ activeB: GPUBuffer;
+ chainBuf: GPUBuffer;
+ tempOut: GPUBuffer;
+ countsA: GPUBuffer;
+ countsB: GPUBuffer;
+ offsetsA: GPUBuffer;
+ offsetsB: GPUBuffer;
+ perLevelChunkPlan: GPUBuffer[];
+ perLevelScatterPlan: GPUBuffer[];
+ perLevelCarryPlan: GPUBuffer[];
+ perLevelTotals: GPUBuffer[];
+ metaParams: GPUBuffer;
+ activeParams: GPUBuffer;
+ plannerParams: GPUBuffer;
+ marshalConsts: GPUBuffer;
+ scatterConsts: GPUBuffer;
+ carryConsts: GPUBuffer;
+ v2RunParams: GPUBuffer;
+ M: number;
+ maxChunks: number;
+ perThread: number;
+}
+
+function allocScratch(
+ device: GPUDevice,
+ num_columns: number,
+ input_size: number,
+ s: number,
+ max_levels: number,
+ tpb: number,
+ per_thread: number,
+): Scratch {
+ // M = real slots (input_size) + 3 reserved tail slots: pad_left,
+ // pad_right, discard. Reserved slots aren't touched by the converter
+ // or by real bucket reductions (which only address [0, total_actives) <
+ // input_size), so the planner can safely pad-fill the last partial
+ // chunk of chunk_plan / scatter_plan with these constant indices.
+ const M = input_size + 3;
+ const maxChunks = Math.max(1, Math.ceil(input_size / 2 / s) + 1);
+
+ const mk = (bytes: number, extra: GPUBufferUsageFlags = 0): GPUBuffer =>
+ device.createBuffer({ size: bytes, usage: GPUBufferUsage.STORAGE | extra });
+
+ const activeBytes = 2 * PG * M * PG_VEC4_BYTES;
+ const activeA = mk(activeBytes, GPUBufferUsage.COPY_DST);
+ const activeB = mk(activeBytes, GPUBufferUsage.COPY_DST);
+
+ const chainBuf = mk(2 * PG * (2 * s * maxChunks) * PG_VEC4_BYTES);
+ const tempOut = mk(2 * PG * (s * maxChunks) * PG_VEC4_BYTES);
+
+ const countsBytes = num_columns * 4;
+ const offsetsBytes = num_columns * 4;
+ const countsA = mk(countsBytes);
+ const countsB = mk(countsBytes);
+ const offsetsA = mk(offsetsBytes);
+ const offsetsB = mk(offsetsBytes);
+
+ const perLevelChunkPlan: GPUBuffer[] = [];
+ const perLevelScatterPlan: GPUBuffer[] = [];
+ const perLevelCarryPlan: GPUBuffer[] = [];
+ const perLevelTotals: GPUBuffer[] = [];
+ const chunkPlanBytes = 2 * s * maxChunks * 4;
+ const scatterPlanBytes = s * maxChunks * 4;
+ const carryPlanBytes = 2 * num_columns * 4;
+ const totalsBytes = Math.max(40, Math.ceil(40 / 16) * 16);
+ for (let lvl = 0; lvl < max_levels; lvl++) {
+ perLevelChunkPlan.push(mk(chunkPlanBytes));
+ perLevelScatterPlan.push(mk(scatterPlanBytes));
+ perLevelCarryPlan.push(mk(carryPlanBytes));
+ perLevelTotals.push(mk(totalsBytes, GPUBufferUsage.INDIRECT | GPUBufferUsage.COPY_DST));
+ }
+
+ const ub = (bytes: number): GPUBuffer =>
+ device.createBuffer({ size: Math.max(16, bytes), usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST });
+ const metaParams = ub(16);
+ const activeParams = ub(16);
+ const plannerParams = ub(16);
+ const marshalConsts = ub(16);
+ const scatterConsts = ub(16);
+ const carryConsts = ub(16);
+ const v2RunParams = ub(16);
+
+ return {
+ activeA, activeB, chainBuf, tempOut,
+ countsA, countsB, offsetsA, offsetsB,
+ perLevelChunkPlan, perLevelScatterPlan, perLevelCarryPlan, perLevelTotals,
+ metaParams, activeParams, plannerParams,
+ marshalConsts, scatterConsts, carryConsts, v2RunParams,
+ M, maxChunks, perThread: per_thread,
+ };
+}
+
+function destroyScratch(scratch: Scratch): void { // Destroys every GPU buffer held by `scratch`; its numeric fields (M, maxChunks, perThread) need no cleanup.
+ scratch.activeA.destroy(); // ping-pong active-sum buffers
+ scratch.activeB.destroy();
+ scratch.chainBuf.destroy(); // bound as marshal output / disjoint input in runSmvpV2PairTree
+ scratch.tempOut.destroy(); // bound as disjoint output / scatter input in runSmvpV2PairTree
+ scratch.countsA.destroy(); // ping-pong per-column counts / offsets
+ scratch.countsB.destroy();
+ scratch.offsetsA.destroy();
+ scratch.offsetsB.destroy();
+ for (const b of scratch.perLevelChunkPlan) b.destroy(); // the perLevel* arrays hold one buffer per level
+ for (const b of scratch.perLevelScatterPlan) b.destroy();
+ for (const b of scratch.perLevelCarryPlan) b.destroy();
+ for (const b of scratch.perLevelTotals) b.destroy();
+ scratch.metaParams.destroy(); // remaining fields are the small uniform buffers
+ scratch.activeParams.destroy();
+ scratch.plannerParams.destroy();
+ scratch.marshalConsts.destroy();
+ scratch.scatterConsts.destroy();
+ scratch.carryConsts.destroy();
+ scratch.v2RunParams.destroy();
+}
+
+/**
+ * Run the v2 pair-tree MSM bucket-accumulate for ALL pippenger windows (subtasks) in a single GPU submit.
+ * Per subtask the encoder records csrMeta + csrActive, then `max_levels` rounds of
+ * planner -> marshal -> disjoint -> scatter -> carry over ping-ponged scratch buffers, then v2ToRunning.
+ * On return the caller's running_x / running_y / bucket_active buffers hold each bucket's reduced packed
+ * point (or 0/inactive marker) ready for batch_affine_finalize_collect to consume. The caller's val_idx /
+ * row_ptr / cached-bases buffers are read-only. Resolves with pass-count and wall-time stats once the queue
+ * has drained; throws an Error when a caller-supplied tpb / per_thread gives tpb * per_thread < num_columns.
+ */
+export async function runSmvpV2PairTree(opts: SmvpV2PairTreeOptions): Promise {
+ const {
+ device, shaderManager, num_subtasks, num_columns, input_size,
+ val_idx_buf, row_ptr_buf, point_x_buf, point_y_buf,
+ running_x_buf, running_y_buf, bucket_active_buf,
+ } = opts;
+ const s = opts.s ?? 16; // S: pairs per chunk, forwarded to compileAll and allocScratch
+ const tpb = opts.tpb ?? 256;
+ const per_thread = opts.per_thread ?? Math.max(1, Math.ceil(num_columns / tpb)); // default always satisfies the check below
+ const wgi = opts.wgi ?? 64; // divisor for the direct dispatch counts below; also forwarded to compileAll
+ const max_levels = opts.max_levels ?? 8; // number of reduction levels encoded per subtask
+ if (tpb * per_thread < num_columns) {
+ throw new Error(`smvp_v2_pair_tree: tpb*per_thread (${tpb}*${per_thread}=${tpb * per_thread}) must be >= num_columns (${num_columns}).`);
+ }
+
+ const pipelines = await compileAll(device, shaderManager, wgi, s, tpb, per_thread);
+ const scratch = allocScratch(device, num_columns, input_size, s, max_levels, tpb, per_thread); // NOTE(review): not released if anything below throws (no try/finally around destroyScratch)
+ const M = scratch.M; // elements per plane of the active buffers; used for the plane-Y byte offset below
+
+ const padLeft = input_size; // three reserved slot indices directly after the input_size data slots
+ const padRight = input_size + 1;
+ const discard = input_size + 2;
+ device.queue.writeBuffer(scratch.metaParams, 0, new Uint32Array([num_columns, num_columns, 0, 0])); // uniform field meanings are defined by the shaders compiled in compileAll (not visible here)
+ device.queue.writeBuffer(scratch.activeParams, 0, new Uint32Array([input_size, M, 0, 0]));
+ device.queue.writeBuffer(scratch.plannerParams, 0, new Uint32Array([num_columns, padLeft, padRight, discard]));
+ device.queue.writeBuffer(scratch.marshalConsts, 0, new Uint32Array([M, 0, 0, 0]));
+ device.queue.writeBuffer(scratch.scatterConsts, 0, new Uint32Array([M, 0, 0, 0]));
+ device.queue.writeBuffer(scratch.carryConsts, 0, new Uint32Array([M, M, 0, 0])); // old and new active buffers have the same size, hence M twice (presumably M_old, M_new — confirm against the carry shader)
+ device.queue.writeBuffer(scratch.v2RunParams, 0, new Uint32Array([num_columns, M, 0, 0]));
+
+ // Pad pair init. The planner's tail pad-fill writes chunk_plan entries
+ // (padLeft, padRight) and scatter_plan entries = discard for the
+ // unused slots of the last partial chunk. The disjoint kernel will
+ // then compute a (garbage) affine add of active_sums[padLeft] and
+ // active_sums[padRight] and scatter the result to
+ // active_sums_new[discard]. Pad slots must have distinct x so the
+ // affine-add formula doesn't divide by zero. Discard slot is a
+ // never-read tail slot reserved beyond the real bucket data.
+ const padPair = new Uint32Array(2 * PG * 4); // PG*4 u32 per slot, two adjacent slots (padLeft then padRight) — assumes PG_VEC4_BYTES is 16
+ for (let i = 0; i < padPair.length; i++) padPair[i] = (0x9e3779b9 * (i + 1)) >>> 0; // deterministic filler: multiples of 0x9e3779b9 reduced mod 2^32
+ if (padPair[0] === padPair[PG * 4]) padPair[PG * 4] ^= 1; // first word of padLeft vs first word of padRight: force them apart if equal
+ const planeXPadByteOff = PG * padLeft * PG_VEC4_BYTES; // plane X starts at byte 0 of the active buffer
+ const planeYPadByteOff = PG * M * PG_VEC4_BYTES + PG * padLeft * PG_VEC4_BYTES; // plane Y starts after M elements of plane X
+ device.queue.writeBuffer(scratch.activeA, planeXPadByteOff, padPair as BufferSource); // same filler goes into both planes of both ping-pong buffers
+ device.queue.writeBuffer(scratch.activeA, planeYPadByteOff, padPair as BufferSource);
+ device.queue.writeBuffer(scratch.activeB, planeXPadByteOff, padPair as BufferSource);
+ device.queue.writeBuffer(scratch.activeB, planeYPadByteOff, padPair as BufferSource);
+
+ const encoder = device.createCommandEncoder(); // all passes of all subtasks are recorded here and submitted once
+ let totalPasses = 0; // incremented by both pass helpers; returned as total_passes
+ const directPass = (pipe: GPUComputePipeline, bind: GPUBindGroup, x: number, y = 1, z = 1): void => { // one compute pass holding one host-sized dispatch
+ const pass = encoder.beginComputePass();
+ pass.setPipeline(pipe);
+ pass.setBindGroup(0, bind);
+ pass.dispatchWorkgroups(x, y, z);
+ pass.end();
+ totalPasses++;
+ };
+ const indirectPass = (pipe: GPUComputePipeline, bind: GPUBindGroup, argsBuf: GPUBuffer, byteOffset: number): void => { // one compute pass whose workgroup counts are read from argsBuf at byteOffset
+ const pass = encoder.beginComputePass();
+ pass.setPipeline(pipe);
+ pass.setBindGroup(0, bind);
+ pass.dispatchWorkgroupsIndirect(argsBuf, byteOffset);
+ pass.end();
+ totalPasses++;
+ };
+
+ const valIdxStride = input_size * 4; // per-subtask byte strides into the caller's buffers
+ const rowPtrStride = (num_columns + 1) * 4;
+ const runningStride = num_columns * ELEMENT_BYTES;
+ const bucketActiveStride = num_columns * 4;
+
+ for (let st = 0; st < num_subtasks; st++) {
+ const valIdxView = { buffer: val_idx_buf, offset: st * valIdxStride, size: valIdxStride } as const;
+ const rowPtrView = { buffer: row_ptr_buf, offset: st * rowPtrStride, size: rowPtrStride } as const; // NOTE(review): if bound as storage, WebGPU requires offset % minStorageBufferOffsetAlignment == 0; (num_columns + 1) * 4 is generally not so aligned — confirm for st > 0
+
+ const metaBind = device.createBindGroup({
+ layout: pipelines.layouts.meta,
+ entries: [
+ { binding: 0, resource: rowPtrView },
+ { binding: 1, resource: { buffer: scratch.countsA } },
+ { binding: 2, resource: { buffer: scratch.offsetsA } },
+ { binding: 3, resource: { buffer: scratch.metaParams } },
+ ],
+ });
+ directPass(pipelines.csrMeta, metaBind, Math.ceil(num_columns / wgi)); // row_ptr slice -> countsA / offsetsA (level-0 inputs of the planner)
+
+ const activeBind = device.createBindGroup({
+ layout: pipelines.layouts.active,
+ entries: [
+ { binding: 0, resource: valIdxView },
+ { binding: 1, resource: { buffer: point_x_buf } },
+ { binding: 2, resource: { buffer: point_y_buf } },
+ { binding: 3, resource: { buffer: scratch.activeA } },
+ { binding: 4, resource: { buffer: scratch.activeParams } },
+ ],
+ });
+ directPass(pipelines.csrActive, activeBind, Math.ceil(input_size / wgi)); // val_idx slice + base points -> activeA (level-0 active buffer)
+
+ let curActive: GPUBuffer = scratch.activeA; // every subtask restarts the ping-pong from the A buffers
+ let nextActive: GPUBuffer = scratch.activeB;
+ let curCounts: GPUBuffer = scratch.countsA;
+ let curOffsets: GPUBuffer = scratch.offsetsA;
+ let nextCounts: GPUBuffer = scratch.countsB;
+ let nextOffsets: GPUBuffer = scratch.offsetsB;
+
+ for (let lvl = 0; lvl < max_levels; lvl++) { // always encodes max_levels levels; there is no host-side early exit
+ const chunkPlanBuf = scratch.perLevelChunkPlan[lvl]; // per-level buffers are indexed by lvl only, so all subtasks reuse them
+ const scatterPlanBuf = scratch.perLevelScatterPlan[lvl];
+ const carryPlanBuf = scratch.perLevelCarryPlan[lvl];
+ const totalsBuf = scratch.perLevelTotals[lvl];
+
+ const plannerBind = device.createBindGroup({
+ layout: pipelines.layouts.planner,
+ entries: [
+ { binding: 0, resource: { buffer: curCounts } },
+ { binding: 1, resource: { buffer: curOffsets } },
+ { binding: 2, resource: { buffer: chunkPlanBuf } },
+ { binding: 3, resource: { buffer: scatterPlanBuf } },
+ { binding: 4, resource: { buffer: carryPlanBuf } },
+ { binding: 5, resource: { buffer: nextCounts } },
+ { binding: 6, resource: { buffer: nextOffsets } },
+ { binding: 7, resource: { buffer: totalsBuf } },
+ { binding: 8, resource: { buffer: scratch.plannerParams } },
+ ],
+ });
+ directPass(pipelines.planner, plannerBind, 1); // single workgroup; its totalsBuf output supplies the indirect dispatch args used below
+
+ const marshalBind = device.createBindGroup({
+ layout: pipelines.layouts.marshal,
+ entries: [
+ { binding: 0, resource: { buffer: chunkPlanBuf } },
+ { binding: 1, resource: { buffer: curActive } },
+ { binding: 2, resource: { buffer: scratch.chainBuf } },
+ { binding: 3, resource: { buffer: totalsBuf } },
+ { binding: 4, resource: { buffer: scratch.marshalConsts } },
+ ],
+ });
+ indirectPass(pipelines.marshal, marshalBind, totalsBuf, 16); // byte 16 = u32 words 4..6 of totalsBuf
+
+ const disjointBind = device.createBindGroup({
+ layout: pipelines.layouts.disjoint,
+ entries: [
+ { binding: 0, resource: { buffer: scratch.chainBuf } },
+ { binding: 1, resource: { buffer: scratch.chainBuf } },
+ { binding: 2, resource: { buffer: scratch.tempOut } },
+ { binding: 3, resource: { buffer: totalsBuf } },
+ ],
+ });
+ indirectPass(pipelines.disjoint, disjointBind, totalsBuf, 16); // same dispatch args as marshal
+
+ const scatterBind = device.createBindGroup({
+ layout: pipelines.layouts.scatter,
+ entries: [
+ { binding: 0, resource: { buffer: scatterPlanBuf } },
+ { binding: 1, resource: { buffer: scratch.tempOut } },
+ { binding: 2, resource: { buffer: nextActive } },
+ { binding: 3, resource: { buffer: totalsBuf } },
+ { binding: 4, resource: { buffer: scratch.scatterConsts } },
+ ],
+ });
+ indirectPass(pipelines.scatter, scatterBind, totalsBuf, 16); // same dispatch args as marshal
+
+ const carryBind = device.createBindGroup({
+ layout: pipelines.layouts.carry,
+ entries: [
+ { binding: 0, resource: { buffer: carryPlanBuf } },
+ { binding: 1, resource: { buffer: curActive } },
+ { binding: 2, resource: { buffer: nextActive } },
+ { binding: 3, resource: { buffer: totalsBuf } },
+ { binding: 4, resource: { buffer: scratch.carryConsts } },
+ ],
+ });
+ indirectPass(pipelines.carry, carryBind, totalsBuf, 28); // byte 28 = u32 words 7..9 of totalsBuf
+
+ [curActive, nextActive] = [nextActive, curActive]; // this level's outputs become the next level's inputs
+ [curCounts, nextCounts] = [nextCounts, curCounts];
+ [curOffsets, nextOffsets] = [nextOffsets, curOffsets];
+ }
+
+ const subtaskBucketOff = st * num_columns; // element offset of this subtask's slice in the caller's output buffers
+ const v2RunBind = device.createBindGroup({
+ layout: pipelines.layouts.v2Run,
+ entries: [
+ { binding: 0, resource: { buffer: curActive } },
+ { binding: 1, resource: { buffer: curCounts } },
+ { binding: 2, resource: { buffer: curOffsets } },
+ { binding: 3, resource: { buffer: running_x_buf, offset: subtaskBucketOff * ELEMENT_BYTES, size: runningStride } },
+ { binding: 4, resource: { buffer: running_y_buf, offset: subtaskBucketOff * ELEMENT_BYTES, size: runningStride } },
+ { binding: 5, resource: { buffer: bucket_active_buf, offset: subtaskBucketOff * 4, size: bucketActiveStride } },
+ { binding: 6, resource: { buffer: scratch.v2RunParams } },
+ ],
+ });
+ directPass(pipelines.v2ToRunning, v2RunBind, Math.ceil(num_columns / wgi)); // after the swaps, cur* hold the last level's output
+ }
+
+ const t0 = performance.now(); // timing starts at submit: pipeline compile, allocation and encoding above are excluded
+ device.queue.submit([encoder.finish()]);
+ await device.queue.onSubmittedWorkDone();
+ const gpu_wall_ms = performance.now() - t0; // host wall-clock from submit to queue drain, not a GPU timestamp
+
+ destroyScratch(scratch); // safe to release: the submitted work has completed
+
+ return {
+ levels_per_window: max_levels, // the number of levels encoded, not a measured tree depth
+ num_subtasks,
+ num_columns,
+ total_passes: totalPasses,
+ gpu_wall_ms,
+ };
+}
+
+export function maxChunksUpperBound(input_size: number, num_columns: number, s: number): number { // Chunk-count bound used by the `sizes` helpers: ceil((input_size / 2) / s) + num_columns, never less than 1.
+ return Math.max(1, Math.ceil(input_size / 2 / s) + num_columns); // input_size / 2 pairs at s pairs per chunk; the + num_columns term is presumably per-column slack — confirm against the planner
+}
+
+export const sizes = { // Byte-size helpers. Formulas have the same shape as those in allocScratch, with the literal 16 where allocScratch uses PG_VEC4_BYTES and maxChunksUpperBound() as the chunk count.
+ activeSumsBytes(input_size: number): number { // 2 planes * PG vec4 per element * M elements * 16 bytes
+ const M = input_size + 2; // NOTE(review): runSmvpV2PairTree uses slot indices input_size, input_size + 1 and input_size + 2 (padLeft / padRight / discard), which needs input_size + 3 slots — confirm this matches allocScratch's M
+ return 2 * PG * M * 16;
+ },
+ chainBufBytes(input_size: number, num_columns: number, s: number): number { // 2 planes, 2 * s elements per chunk
+ const T = maxChunksUpperBound(input_size, num_columns, s);
+ return 2 * PG * (2 * s * T) * 16;
+ },
+ tempOutBytes(input_size: number, num_columns: number, s: number): number { // 2 planes, s elements per chunk
+ const T = maxChunksUpperBound(input_size, num_columns, s);
+ return 2 * PG * (s * T) * 16;
+ },
+ chunkPlanBytes(input_size: number, num_columns: number, s: number): number { // 2 * s u32 per chunk
+ const T = maxChunksUpperBound(input_size, num_columns, s);
+ return 2 * s * T * 4;
+ },
+ scatterPlanBytes(input_size: number, num_columns: number, s: number): number { // s u32 per chunk
+ const T = maxChunksUpperBound(input_size, num_columns, s);
+ return s * T * 4;
+ },
+ carryPlanBytes(num_columns: number): number { // 2 u32 per column
+ return 2 * num_columns * 4;
+ },
+ countsBytes(num_columns: number): number { // 1 u32 per column
+ return num_columns * 4;
+ },
+ offsetsBytes(num_columns: number): number { // 1 u32 per column
+ return num_columns * 4;
+ },
+};
diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts
index 741d3480f737..bef7e380beeb 100644
--- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts
+++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts
@@ -1,6 +1,6 @@
// AUTO-GENERATED by scripts/inline-wgsl.mjs. DO NOT EDIT.
// Run `yarn generate:wgsl` (or `node scripts/inline-wgsl.mjs`) to regenerate.
-// 48 shader sources inlined.
+// 70 shader sources inlined.
/* eslint-disable */
@@ -1351,6 +1351,1870 @@ fn main(@builtin(global_invocation_id) gid: vec3) {
}
`;
+export const ba_carry_copy_bench = /* WGSL template: one thread per carry t copies the packed (x, y) point at active_sums_old[carry_plan[2t]] to active_sums_new[carry_plan[2t+1]]; T, M_old, M_new come from the params uniform. */ `{{> structs }}
+
+// Carry-copy kernel for the bin-packed pair-tree MSM bucket-accumulate.
+//
+// For each carry slot t, copies one packed (x, y) point from
+// active_sums_old[carry_plan[2*t + 0]] to
+// active_sums_new[carry_plan[2*t + 1]].
+//
+// Used when a bucket has an odd active count at the current level:
+// floor(N_b / 2) elements get paired and produce floor(N_b / 2) sums
+// in the next level, plus the (N_b mod 2 == 1) carry element propagates
+// forward unchanged.
+//
+// Pure memory shuffle, no field arithmetic.
+//
+// params.x = T (number of carry-copies / threads)
+// params.y = M_old (active_sums_old size, vec4-stride scaling)
+// params.z = M_new (active_sums_new size, vec4-stride scaling)
+
+const PG: u32 = 2u;
+
+@group(0) @binding(0) var carry_plan: array;
+@group(0) @binding(1) var active_sums_old: array>;
+@group(0) @binding(2) var active_sums_new: array>;
+@group(0) @binding(3) var params: vec4;
+
+@compute @workgroup_size({{ workgroup_size }})
+fn main(@builtin(global_invocation_id) gid: vec3) {
+ let T = params.x;
+ let M_old = params.y;
+ let M_new = params.z;
+ let t = gid.x;
+ if (t >= T) { return; }
+
+ let src_idx = carry_plan[2u * t + 0u];
+ let dst_idx = carry_plan[2u * t + 1u];
+
+ let old_plane_x = 0u * PG * M_old;
+ let old_plane_y = 1u * PG * M_old;
+ let new_plane_x = 0u * PG * M_new;
+ let new_plane_y = 1u * PG * M_new;
+
+ let src_x = old_plane_x + PG * src_idx;
+ let src_y = old_plane_y + PG * src_idx;
+ let dst_x = new_plane_x + PG * dst_idx;
+ let dst_y = new_plane_y + PG * dst_idx;
+
+ active_sums_new[dst_x + 0u] = active_sums_old[src_x + 0u];
+ active_sums_new[dst_x + 1u] = active_sums_old[src_x + 1u];
+ active_sums_new[dst_y + 0u] = active_sums_old[src_y + 0u];
+ active_sums_new[dst_y + 1u] = active_sums_old[src_y + 1u];
+
+ {{{ recompile }}}
+}
+`;
+
+export const ba_carry_copy_prod = /* WGSL template: same per-thread copy as ba_carry_copy_bench, but the carry count T is read from totals[1] and M_old / M_new from consts.x / consts.y. */ `{{> structs }}
+
+// Carry-copy kernel — prod variant for the v2 pair-tree integration.
+// num_carries is read from the planner's totals[1] and dispatch is
+// indirect via totals[7..9].
+
+const PG: u32 = 2u;
+
+@group(0) @binding(0) var carry_plan: array;
+@group(0) @binding(1) var active_sums_old: array>;
+@group(0) @binding(2) var active_sums_new: array>;
+@group(0) @binding(3) var totals: array;
+@group(0) @binding(4) var consts: vec4;
+// consts.x = M_old
+// consts.y = M_new
+
+@compute @workgroup_size({{ workgroup_size }})
+fn main(@builtin(global_invocation_id) gid: vec3) {
+ let T = totals[1];
+ let M_old = consts.x;
+ let M_new = consts.y;
+ let t = gid.x;
+ if (t >= T) { return; }
+
+ let src_idx = carry_plan[2u * t + 0u];
+ let dst_idx = carry_plan[2u * t + 1u];
+
+ let old_plane_x = 0u * PG * M_old;
+ let old_plane_y = 1u * PG * M_old;
+ let new_plane_x = 0u * PG * M_new;
+ let new_plane_y = 1u * PG * M_new;
+
+ let src_x = old_plane_x + PG * src_idx;
+ let src_y = old_plane_y + PG * src_idx;
+ let dst_x = new_plane_x + PG * dst_idx;
+ let dst_y = new_plane_y + PG * dst_idx;
+
+ active_sums_new[dst_x + 0u] = active_sums_old[src_x + 0u];
+ active_sums_new[dst_x + 1u] = active_sums_old[src_x + 1u];
+ active_sums_new[dst_y + 0u] = active_sums_old[src_y + 0u];
+ active_sums_new[dst_y + 1u] = active_sums_old[src_y + 1u];
+
+ {{{ recompile }}}
+}
+`;
+
+export const ba_fused_super_bench = /* WGSL template: one thread per chunk of S pairs; gathers operands via chunk_plan, inverts all S dx values with prefix products plus a single fr_inv_by_a, and stores each affine pair sum to active_sums_new at scatter_plan[t*S + k]. */ `{{> structs }}
+{{> bigint_funcs }}
+{{> montgomery_product_funcs }}
+{{> field_funcs }}
+{{> fr_pow_funcs }}
+{{> bigint_by_funcs }}
+{{> by_inverse_a_funcs }}
+
+{{{ dec_unpack }}}
+
+{{{ dec_pack }}}
+
+// Fused super-kernel for the bin-packed pair-tree MSM bucket-accumulate.
+//
+// Combines marshal + disjoint + scatter into one kernel. Each thread t
+// handles one chunk of S pairs:
+// 1. Read 2*S source indices from chunk_plan (idx_l, idx_r per slot).
+// 2. Read S destination indices from scatter_plan.
+// 3. Load S pair-x values from active_sums_old, compute S dx values
+// and forward prefix product, all in registers.
+// 4. Single fr_inv_by_a on the prefix product.
+// 5. Backward peel: per slot k from S-1 down to 0:
+// - load .x and .y for both operands
+// - lean affine add -> R_x, R_y
+// - write directly to active_sums_new at scatter_plan[t*S + k]
+// - update inv for next (smaller-k) iteration
+//
+// vs v2 (4 kernels: marshal, disjoint, scatter, carry): the chain_buf
+// and tempOut scratch buffers are eliminated. All intermediate state
+// lives in registers. Per-level dispatch count drops from 4 to 2
+// (fused + carry).
+//
+// PARAMS:
+// params.x = T_chunks (active threads, one per chunk)
+// params.y = M_old (active_sums_old vec4-stride length)
+// params.z = M_new (active_sums_new vec4-stride length)
+//
+// Layout (both active_sums buffers): 2 planes (P.x, P.y), PG=2 vec4 per
+// element. plane_p flat vec4 base = p * PG * M, element e at offset
+// PG * e.
+
+const S: u32 = {{ s }}u;
+const PG: u32 = 2u;
+
+@group(0) @binding(0) var chunk_plan: array;
+@group(0) @binding(1) var scatter_plan: array;
+@group(0) @binding(2) var active_sums_old: array>;
+@group(0) @binding(3) var active_sums_new: array>;
+@group(0) @binding(4) var params: vec4;
+
+fn load_active_x(idx: u32, M: u32) -> BigInt {
+ let plane_base = 0u * PG * M;
+ let base = plane_base + PG * idx;
+ let q0 = active_sums_old[base + 0u];
+ let q1 = active_sums_old[base + 1u];
+ var w: array;
+ w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w;
+ w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w;
+ return unpack256_to_limbs(w);
+}
+
+fn load_active_y(idx: u32, M: u32) -> BigInt {
+ let plane_base = 1u * PG * M;
+ let base = plane_base + PG * idx;
+ let q0 = active_sums_old[base + 0u];
+ let q1 = active_sums_old[base + 1u];
+ var w: array;
+ w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w;
+ w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w;
+ return unpack256_to_limbs(w);
+}
+
+fn store_active_new(plane: u32, idx: u32, M: u32, val: ptr) {
+ let plane_base = plane * PG * M;
+ let base = plane_base + PG * idx;
+ let w = pack_limbs_to_256(val);
+ active_sums_new[base + 0u] = vec4(w[0], w[1], w[2], w[3]);
+ active_sums_new[base + 1u] = vec4(w[4], w[5], w[6], w[7]);
+}
+
+fn get_r() -> BigInt {
+ var r: BigInt;
+{{{ r_limbs }}}
+ return r;
+}
+
+@compute @workgroup_size({{ workgroup_size }})
+fn main(@builtin(global_invocation_id) gid: vec3) {
+ let T = params.x;
+ let M_old = params.y;
+ let M_new = params.z;
+ let t = gid.x;
+ if (t >= T) { return; }
+
+ let chunk_base = 2u * S * t;
+
+ // Forward: compute S dx values and accumulate prefix product.
+ // Read pair indices from chunk_plan, load .x for each operand, compute dx.
+ var pref: array;
+ var acc: BigInt = get_r();
+ for (var k: u32 = 0u; k < S; k = k + 1u) {
+ let idx_l = chunk_plan[chunk_base + 2u * k + 0u];
+ let idx_r = chunk_plan[chunk_base + 2u * k + 1u];
+ var p_lx: BigInt = load_active_x(idx_l, M_old);
+ var p_rx: BigInt = load_active_x(idx_r, M_old);
+ var dx: BigInt = fr_sub(&p_rx, &p_lx);
+ if (k == 0u) {
+ acc = dx;
+ } else {
+ acc = montgomery_product(&acc, &dx);
+ }
+ pref[k] = acc;
+ }
+
+ // Single inversion per chunk.
+ var inv: BigInt = fr_inv_by_a(acc);
+
+ // Backward peel: emit S pair sums, scatter to active_sums_new.
+ for (var jj: u32 = 0u; jj < S; jj = jj + 1u) {
+ let k = S - 1u - jj;
+ let idx_l = chunk_plan[chunk_base + 2u * k + 0u];
+ let idx_r = chunk_plan[chunk_base + 2u * k + 1u];
+
+ var p_lx: BigInt = load_active_x(idx_l, M_old);
+ var p_ly: BigInt = load_active_y(idx_l, M_old);
+ var p_rx: BigInt = load_active_x(idx_r, M_old);
+ var p_ry: BigInt = load_active_y(idx_r, M_old);
+
+ var inv_dx: BigInt;
+ if (k == 0u) {
+ inv_dx = inv;
+ } else {
+ var pp = pref[k - 1u];
+ inv_dx = montgomery_product(&inv, &pp);
+ }
+
+ var lambda: BigInt = fr_sub(&p_ry, &p_ly);
+ lambda = montgomery_product(&lambda, &inv_dx);
+ var r_x: BigInt = montgomery_product(&lambda, &lambda);
+ r_x = fr_sub(&r_x, &p_lx);
+ r_x = fr_sub(&r_x, &p_rx);
+ var r_y: BigInt = fr_sub(&p_lx, &r_x);
+ r_y = montgomery_product(&lambda, &r_y);
+ r_y = fr_sub(&r_y, &p_ly);
+
+ let dst_idx = scatter_plan[t * S + k];
+ store_active_new(0u, dst_idx, M_new, &r_x);
+ store_active_new(1u, dst_idx, M_new, &r_y);
+
+ if (k > 0u) {
+ var dx_back: BigInt = fr_sub(&p_rx, &p_lx);
+ inv = montgomery_product(&inv, &dx_back);
+ }
+ }
+
+ {{{ recompile }}}
+}
+`;
+
+export const ba_marshal_chain_bench = /* WGSL template: one thread per chunk t; copies point_pool[0] as the seed into chain_buf planes 0/1 at index t, then gathers S points via csr_indices (from chunk_plan[2t+1]) into planes 2/3 at strided index t + i*T. */ `{{> structs }}
+
+// Marshal kernel for the bench-msm-chain pipeline. Transposes a CSR
+// point list (sorted by bucket) into the strided SoA layout the
+// ba_rev_packed_carry_bench chain kernel consumes.
+//
+// Input layout (point_pool):
+// 2 planes (P.x, P.y), each PG=2 vec4 per element, params.y elements total.
+// Plane p at point idx i: vec4 indices p*PG*N + PG*i + {0,1}.
+// Convention: point_pool[0] is the "decoy" — used as the seed for every
+// chunk so the chain kernel's first dx (= P_0.x - seed.x) is well-
+// defined. csr_indices values are in [1, N), never 0.
+//
+// Output layout (chain_buf):
+// 4 planes (A.x, A.y, P.x, P.y), each PG=2 vec4 per element, T*S
+// elements per plane. Plane p at strided element e = t + i*T: vec4
+// indices p*PG*(T*S) + PG*e + {0,1}.
+//
+// Per chunk-thread t:
+// - csr_start = chunk_plan[2*t + 1] (chunk_plan[2*t] = bucket_id, unused here)
+// - Seed at index t (planes 0,1) := point_pool[0] (universal decoy)
+// - For i in 0..S: P_i at index e = t + i*T (planes 2,3)
+// := point_pool[csr_indices[csr_start + i]]
+//
+// The chain kernel then produces S pair-sums per chunk. The S/2 odd-
+// indexed outputs (R_1, R_3, ..., R_{S-1}) are disjoint pair sums of
+// {P_0..P_{S-1}}; the even outputs (R_0, R_2, ...) incorporate the
+// decoy or share a P with the next odd output and are discarded by the
+// subsequent reduce pass.
+//
+// Pure memory-shuffle kernel: no field arithmetic. Reads are coalesced
+// because consecutive threads t, t+1 read adjacent csr_indices entries
+// and the gathered point coords are written to adjacent vec4 slots
+// (PG*e for e=t, t+1, ...).
+
+const S: u32 = {{ s }}u;
+const PG: u32 = 2u;
+
+@group(0) @binding(0) var csr_indices: array;
+@group(0) @binding(1) var chunk_plan: array;
+@group(0) @binding(2) var point_pool: array>;
+@group(0) @binding(3) var chain_buf: array>;
+@group(0) @binding(4) var params: vec4;
+
+@compute @workgroup_size({{ workgroup_size }})
+fn main(@builtin(global_invocation_id) gid: vec3) {
+ let T = params.x;
+ let N = params.y;
+ let t = gid.x;
+ if (t >= T) { return; }
+
+ let csr_start = chunk_plan[2u * t + 1u];
+
+ let chain_N = T * S;
+ let chain_plane = PG * chain_N;
+ let chain_ax_base = 0u * chain_plane;
+ let chain_ay_base = 1u * chain_plane;
+ let chain_px_base = 2u * chain_plane;
+ let chain_py_base = 3u * chain_plane;
+
+ let pool_plane = PG * N;
+ let pool_px_base = 0u * pool_plane;
+ let pool_py_base = 1u * pool_plane;
+
+ // Seed (A.x, A.y at index t) := point_pool[0] (decoy).
+ let decoy_x_off = pool_px_base + PG * 0u;
+ let decoy_y_off = pool_py_base + PG * 0u;
+ let seed_x_off = chain_ax_base + PG * t;
+ let seed_y_off = chain_ay_base + PG * t;
+ chain_buf[seed_x_off + 0u] = point_pool[decoy_x_off + 0u];
+ chain_buf[seed_x_off + 1u] = point_pool[decoy_x_off + 1u];
+ chain_buf[seed_y_off + 0u] = point_pool[decoy_y_off + 0u];
+ chain_buf[seed_y_off + 1u] = point_pool[decoy_y_off + 1u];
+
+ // Gather S points from csr_indices[csr_start..csr_start+S] into the
+ // strided P-planes at indices e = t + i*T for i in 0..S.
+ for (var i = 0u; i < S; i = i + 1u) {
+ let pt_idx = csr_indices[csr_start + i];
+ let e = t + i * T;
+ let pool_x_off = pool_px_base + PG * pt_idx;
+ let pool_y_off = pool_py_base + PG * pt_idx;
+ let chain_px_off = chain_px_base + PG * e;
+ let chain_py_off = chain_py_base + PG * e;
+ chain_buf[chain_px_off + 0u] = point_pool[pool_x_off + 0u];
+ chain_buf[chain_px_off + 1u] = point_pool[pool_x_off + 1u];
+ chain_buf[chain_py_off + 0u] = point_pool[pool_y_off + 0u];
+ chain_buf[chain_py_off + 1u] = point_pool[pool_y_off + 1u];
+ }
+
+ {{{ recompile }}}
+}
+`;
+
+export const ba_marshal_pairs_bench = /* WGSL template: one thread per chunk t; for each of its S pairs, copies the (left, right) operands named by chunk_plan from active_sums into chain_buf at strided slots t + (2k)*T and t + (2k+1)*T; T and M_in come from the params uniform. */ `{{> structs }}
+
+// Marshal kernel for the bin-packed pair-tree MSM bucket-accumulate.
+//
+// Reads (idx_l, idx_r) operand indices per pair from chunk_plan,
+// fetches the corresponding packed 8x u32 points from an active_sums
+// buffer (2-plane SoA), and writes them into the disjoint kernel's
+// strided input layout.
+//
+// Used both at level 0 (active_sums = bucket-sorted point pool) and
+// at levels 1+ (active_sums = previous level's pair-sum + carry
+// outputs). The kernel is bucket-agnostic; the planner has packed
+// each chunk's S pairs from whatever buckets fit, and chunk_plan
+// encodes the operand source indices.
+//
+// chunk_plan layout: 2 * S u32 per chunk
+// chunk_plan[2 * (t * S + k) + 0] = idx_left (active_sums index)
+// chunk_plan[2 * (t * S + k) + 1] = idx_right (active_sums index)
+//
+// active_sums layout: 2 planes (P.x, P.y), PG=2 vec4 per element,
+// M_in elements per plane (params.y).
+//
+// chain_buf layout: 2 planes (P.x, P.y), PG=2 vec4 per element,
+// 2 * S * T elements per plane. Slot (t, 2k+0) holds left, slot
+// (t, 2k+1) holds right at the disjoint kernel's strided positions
+// e = t + i * T for i = 2k, 2k+1.
+
+const S: u32 = {{ s }}u;
+const PG: u32 = 2u;
+
+@group(0) @binding(0) var chunk_plan: array;
+@group(0) @binding(1) var active_sums: array>;
+@group(0) @binding(2) var chain_buf: array>;
+@group(0) @binding(3) var params: vec4;
+
+@compute @workgroup_size({{ workgroup_size }})
+fn main(@builtin(global_invocation_id) gid: vec3) {
+ let T = params.x;
+ let M_in = params.y;
+ let t = gid.x;
+ if (t >= T) { return; }
+
+ let chain_N = 2u * S * T;
+ let chain_plane_x = 0u * PG * chain_N;
+ let chain_plane_y = 1u * PG * chain_N;
+
+ let active_plane_x = 0u * PG * M_in;
+ let active_plane_y = 1u * PG * M_in;
+
+ let chunk_base = 2u * S * t;
+ for (var k: u32 = 0u; k < S; k = k + 1u) {
+ let idx_l = chunk_plan[chunk_base + 2u * k + 0u];
+ let idx_r = chunk_plan[chunk_base + 2u * k + 1u];
+
+ let e_l = t + (2u * k + 0u) * T;
+ let e_r = t + (2u * k + 1u) * T;
+
+ let src_lx = active_plane_x + PG * idx_l;
+ let src_ly = active_plane_y + PG * idx_l;
+ let src_rx = active_plane_x + PG * idx_r;
+ let src_ry = active_plane_y + PG * idx_r;
+
+ let dst_lx = chain_plane_x + PG * e_l;
+ let dst_ly = chain_plane_y + PG * e_l;
+ let dst_rx = chain_plane_x + PG * e_r;
+ let dst_ry = chain_plane_y + PG * e_r;
+
+ chain_buf[dst_lx + 0u] = active_sums[src_lx + 0u];
+ chain_buf[dst_lx + 1u] = active_sums[src_lx + 1u];
+ chain_buf[dst_ly + 0u] = active_sums[src_ly + 0u];
+ chain_buf[dst_ly + 1u] = active_sums[src_ly + 1u];
+ chain_buf[dst_rx + 0u] = active_sums[src_rx + 0u];
+ chain_buf[dst_rx + 1u] = active_sums[src_rx + 1u];
+ chain_buf[dst_ry + 0u] = active_sums[src_ry + 0u];
+ chain_buf[dst_ry + 1u] = active_sums[src_ry + 1u];
+ }
+
+ {{{ recompile }}}
+}
+`;
+
+export const ba_marshal_pairs_prod = /* WGSL template: same gather as ba_marshal_pairs_bench, but the chunk count T is read from totals[3] and M_in from consts.x. */ `{{> structs }}
+
+// Marshal kernel — prod variant for the v2 pair-tree integration.
+//
+// Same indexing math as ba_marshal_pairs_bench. The only structural
+// change: the per-level T (= num_chunks) is read from the planner's
+// totals[3] storage output instead of a host-set uniform, and the
+// host dispatches via dispatchWorkgroupsIndirect(totals, 16). This
+// dispatches exactly ceil(num_chunks / WG) workgroups so no pad
+// chunks are computed.
+
+const S: u32 = {{ s }}u;
+const PG: u32 = 2u;
+
+@group(0) @binding(0) var chunk_plan: array;
+@group(0) @binding(1) var active_sums: array>;
+@group(0) @binding(2) var chain_buf: array>;
+@group(0) @binding(3) var totals: array;
+@group(0) @binding(4) var consts: vec4;
+// consts.x = M_in
+
+@compute @workgroup_size({{ workgroup_size }})
+fn main(@builtin(global_invocation_id) gid: vec3) {
+ let T = totals[3];
+ let M_in = consts.x;
+ let t = gid.x;
+ if (t >= T) { return; }
+
+ let chain_N = 2u * S * T;
+ let chain_plane_x = 0u * PG * chain_N;
+ let chain_plane_y = 1u * PG * chain_N;
+
+ let active_plane_x = 0u * PG * M_in;
+ let active_plane_y = 1u * PG * M_in;
+
+ let chunk_base = 2u * S * t;
+ for (var k: u32 = 0u; k < S; k = k + 1u) {
+ let idx_l = chunk_plan[chunk_base + 2u * k + 0u];
+ let idx_r = chunk_plan[chunk_base + 2u * k + 1u];
+
+ let e_l = t + (2u * k + 0u) * T;
+ let e_r = t + (2u * k + 1u) * T;
+
+ let src_lx = active_plane_x + PG * idx_l;
+ let src_ly = active_plane_y + PG * idx_l;
+ let src_rx = active_plane_x + PG * idx_r;
+ let src_ry = active_plane_y + PG * idx_r;
+
+ let dst_lx = chain_plane_x + PG * e_l;
+ let dst_ly = chain_plane_y + PG * e_l;
+ let dst_rx = chain_plane_x + PG * e_r;
+ let dst_ry = chain_plane_y + PG * e_r;
+
+ chain_buf[dst_lx + 0u] = active_sums[src_lx + 0u];
+ chain_buf[dst_lx + 1u] = active_sums[src_lx + 1u];
+ chain_buf[dst_ly + 0u] = active_sums[src_ly + 0u];
+ chain_buf[dst_ly + 1u] = active_sums[src_ly + 1u];
+ chain_buf[dst_rx + 0u] = active_sums[src_rx + 0u];
+ chain_buf[dst_rx + 1u] = active_sums[src_rx + 1u];
+ chain_buf[dst_ry + 0u] = active_sums[src_ry + 0u];
+ chain_buf[dst_ry + 1u] = active_sums[src_ry + 1u];
+ }
+
+ {{{ recompile }}}
+}
+`;
+
+export const ba_marshal_tree_l0_bench = /* WGSL template: one thread per chunk t; gathers 2*S points via csr_indices (starting at chunk_plan[2t+1]) from point_pool into the two chain_buf planes at strided index t + i*T. */ `{{> structs }}
+
+// Marshal kernel for the bench-msm-tree pair-tree pipeline: transposes
+// a CSR-sorted point index list into the 2-plane strided SoA layout
+// the ba_pair_disjoint_tree kernel consumes at level 0. Pure memory
+// shuffle, no field arithmetic.
+//
+// Input (point_pool):
+// 2 planes (P.x, P.y), each PG=2 vec4 per element, N pool elements.
+// Plane p flat vec4 indices: p*PG*N + PG*i + {0,1}.
+//
+// Output (chain_buf):
+// 2 planes (P.x, P.y), each PG=2 vec4 per element, 2*S*T elements
+// per plane. Plane p at strided element e = t + i*T: vec4 indices
+// p*PG*(2*S*T) + PG*e + {0,1}.
+//
+// Per chunk-thread t with CSR slice [csr_start, csr_start + 2*S):
+// For i in 0..2*S:
+// pt_idx = csr_indices[csr_start + i]
+// copy point_pool[pt_idx] (P.x, P.y) into chain_buf at e = t + i*T
+
+const S: u32 = {{ s }}u;
+const TWOS: u32 = 2u * S;
+const PG: u32 = 2u;
+
+@group(0) @binding(0) var csr_indices: array;
+@group(0) @binding(1) var chunk_plan: array;
+@group(0) @binding(2) var point_pool: array>;
+@group(0) @binding(3) var chain_buf: array>;
+@group(0) @binding(4) var params: vec4;
+
+@compute @workgroup_size({{ workgroup_size }})
+fn main(@builtin(global_invocation_id) gid: vec3) {
+ let T = params.x;
+ let N = params.y;
+ let t = gid.x;
+ if (t >= T) { return; }
+
+ let csr_start = chunk_plan[2u * t + 1u];
+
+ let chain_N = TWOS * T;
+ let chain_plane = PG * chain_N;
+ let chain_px_base = 0u * chain_plane;
+ let chain_py_base = 1u * chain_plane;
+
+ let pool_plane = PG * N;
+ let pool_px_base = 0u * pool_plane;
+ let pool_py_base = 1u * pool_plane;
+
+ for (var i: u32 = 0u; i < TWOS; i = i + 1u) {
+ let pt_idx = csr_indices[csr_start + i];
+ let e = t + i * T;
+ let pool_x_off = pool_px_base + PG * pt_idx;
+ let pool_y_off = pool_py_base + PG * pt_idx;
+ let chain_px_off = chain_px_base + PG * e;
+ let chain_py_off = chain_py_base + PG * e;
+ chain_buf[chain_px_off + 0u] = point_pool[pool_x_off + 0u];
+ chain_buf[chain_px_off + 1u] = point_pool[pool_x_off + 1u];
+ chain_buf[chain_py_off + 0u] = point_pool[pool_y_off + 0u];
+ chain_buf[chain_py_off + 1u] = point_pool[pool_y_off + 1u];
+ }
+
+ {{{ recompile }}}
+}
+`;
+
+export const ba_pair_disjoint_bench = `{{> structs }}
+{{> bigint_funcs }}
+{{> montgomery_product_funcs }}
+{{> field_funcs }}
+{{> fr_pow_funcs }}
+{{> bigint_by_funcs }}
+{{> by_inverse_a_funcs }}
+
+{{{ dec_unpack }}}
+
+{{{ dec_pack }}}
+
+// Disjoint pair-sum kernel — each thread reduces 2*S input points to S
+// disjoint pair sums R_k = P_{2k} + P_{2k+1} (k in 0..S) using the
+// same forward-prefix / single-inversion / backward-peel batched-
+// inverse pattern as ba_rev_packed_carry, but with NO load-carry
+// overlap. Every kernel-output is a distinct pair sum suitable as
+// input to the next level of a pair-tree reduction — closes the 50%
+// kernel-efficiency loss inherent in the streaming chain kernel.
+//
+// Storage: SoA-packed 8x u32 per field (PG=2 vec4/elem).
+// Input planes (binding 0):
+// plane 0 (P.x): PG * N_in vec4, N_in = 2*S*T
+// plane 1 (P.y): PG * N_in vec4
+// Output planes (binding 2; binding 1 is declared but never accessed):
+// plane 0 (R.x): PG * N_out vec4, N_out = S*T
+// plane 1 (R.y): PG * N_out vec4
+//
+// Thread t reads P_i = (inp[plane c at index t + i*T] : c in {0,1}) for
+// i in 0..2S (strided => coalesced). Pair k pairs adjacent strided
+// slots: (P_{2k}, P_{2k+1}). Output R_k is written at index t + k*T in
+// plane c of outp (also strided, coalesced).
+//
+// dx values dx_k = P_{2k+1}.x - P_{2k}.x are all mutually independent
+// (no shared inputs across k), so the standard Montgomery batched
+// inverse trick applies as-is: ONE fr_inv_by_a per chunk of S.
+//
+// Same Karatsuba+Yuval montmul and BY-safegcd fr_inv_by_a as the
+// production stack and the chain kernel.
+
+const S: u32 = {{ s }}u;
+const PG: u32 = 2u;
+
+@group(0) @binding(0) var inp: array>;
+@group(0) @binding(1) var unused: array>;
+@group(0) @binding(2) var outp: array>;
+@group(0) @binding(3) var params: vec4;
+
+fn load_in(plane: u32, t: u32, i: u32, T: u32, N_in: u32) -> BigInt {
+ let plane_base = plane * PG * N_in; // first vec4 of the requested plane
+ let base = plane_base + PG * (t + i * T); // element t + i*T, PG vec4 per element
+ let q0 = inp[base + 0u];
+ let q1 = inp[base + 1u];
+ var w: array;
+ w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w;
+ w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w;
+ return unpack256_to_limbs(w); // helper presumably injected by dec_unpack - confirm
+}
+
+fn store_out(plane: u32, t: u32, k: u32, T: u32, N_out: u32, val: ptr) {
+ let plane_base = plane * PG * N_out; // first vec4 of the requested output plane
+ let base = plane_base + PG * (t + k * T); // element t + k*T, PG vec4 per element
+ let w = pack_limbs_to_256(val); // helper presumably injected by dec_pack - confirm
+ outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]);
+ outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]);
+}
+
+fn get_r() -> BigInt {
+ var r: BigInt;
+{{{ r_limbs }}}
+ return r;
+}
+
+@compute @workgroup_size({{ workgroup_size }})
+fn main(@builtin(global_invocation_id) gid: vec3) {
+ let N_in = params.x; // input elements per plane (header: N_in = 2*S*T)
+ let T = params.y; // thread count; also the element stride between one thread's slots
+ let N_out = N_in / 2u; // one output element per input pair
+
+ let t = gid.x;
+ if (t >= T) { return; } // invocations beyond T do nothing
+
+ // Forward: prefix product of S independent dx values.
+ var pref: array;
+ var acc: BigInt = get_r(); // placeholder value; replaced by dx at k == 0
+ for (var k: u32 = 0u; k < S; k = k + 1u) {
+ var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T, N_in);
+ var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T, N_in);
+ var dx: BigInt = fr_sub(&p_rx, &p_lx);
+ if (k == 0u) {
+ acc = dx;
+ } else {
+ acc = montgomery_product(&acc, &dx);
+ }
+ pref[k] = acc; // pref[k] = dx_0 * ... * dx_k
+ }
+
+ // One BY-safegcd inversion amortised over all S pair sums. NOTE(review): a single dx == 0 zeroes acc, so every output of this thread then depends on the inverse of zero; inputs must have distinct x within each pair.
+ var inv: BigInt = fr_inv_by_a(acc);
+
+ // Backward peel: emit S disjoint pair sums.
+ for (var jj: u32 = 0u; jj < S; jj = jj + 1u) {
+ let k = S - 1u - jj; // walk k from S-1 down to 0
+
+ var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T, N_in);
+ var p_ly: BigInt = load_in(1u, t, 2u * k + 0u, T, N_in);
+ var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T, N_in);
+ var p_ry: BigInt = load_in(1u, t, 2u * k + 1u, T, N_in);
+
+ var inv_dx: BigInt;
+ if (k == 0u) {
+ inv_dx = inv;
+ } else {
+ var pp = pref[k - 1u];
+ inv_dx = montgomery_product(&inv, &pp); // inv is 1/pref[k] here, so this is 1/dx_k
+ }
+
+ var lambda: BigInt = fr_sub(&p_ry, &p_ly); // numerator y_r - y_l
+ lambda = montgomery_product(&lambda, &inv_dx); // lambda = (y_r - y_l) / (x_r - x_l)
+ var r_x: BigInt = montgomery_product(&lambda, &lambda);
+ r_x = fr_sub(&r_x, &p_lx);
+ r_x = fr_sub(&r_x, &p_rx); // r_x = lambda^2 - x_l - x_r
+ var r_y: BigInt = fr_sub(&p_lx, &r_x);
+ r_y = montgomery_product(&lambda, &r_y);
+ r_y = fr_sub(&r_y, &p_ly); // r_y = lambda * (x_l - r_x) - y_l
+
+ store_out(0u, t, k, T, N_out, &r_x);
+ store_out(1u, t, k, T, N_out, &r_y);
+
+ // Advance inv to 1/pref[k-1] for the next (smaller) iteration.
+ if (k > 0u) {
+ var dx_back: BigInt = fr_sub(&p_rx, &p_lx);
+ inv = montgomery_product(&inv, &dx_back);
+ }
+ }
+
+ {{{ recompile }}}
+}
+`;
+
+export const ba_pair_disjoint_tree_bench = `{{> structs }}
+{{> bigint_funcs }}
+{{> montgomery_product_funcs }}
+{{> field_funcs }}
+{{> fr_pow_funcs }}
+{{> bigint_by_funcs }}
+{{> by_inverse_a_funcs }}
+
+{{{ dec_unpack }}}
+
+{{{ dec_pack }}}
+
+// Disjoint pair-sum kernel — tree variant. Each thread reduces 2*S
+// input points to S disjoint pair sums R_k = P_{2k} + P_{2k+1}, using
+// one batched fr_inv_by_a per chunk of S.
+//
+// vs ba_pair_disjoint_bench: writes outputs in the LAYOUT THE NEXT
+// PAIR-TREE LEVEL EXPECTS AS INPUT, eliminating the need for an
+// intervening marshal/reshuffle dispatch between levels.
+//
+// Strided read at the current level: thread t reads input slot i at flat
+// in_pos(t, i) = t + i * T_curr (i in [0, 2*S))
+//
+// Strided write that next level reads correctly: thread t writes
+// output slot i at flat
+// out_pos(t, i) = (t >> 1) + (i + S * (t & 1)) * (T_curr >> 1)
+//
+// Derivation: next level uses T_next = T_curr / 2 threads. For
+// next-level thread t_n = t >> 1 to read its 2*S inputs in the right
+// pair-tree order (first S from prev thread (2*t_n), next S from prev
+// thread (2*t_n + 1)), the current level's output slots interleave:
+// odd-t writes go into the upper-S input slots of the next level's
+// thread (t >> 1), even-t into the lower-S slots.
+//
+// This preserves the per-bucket-pair invariant: at every level, the
+// disjoint pairs (P_{2j}, P_{2j+1}) belong to the same bucket pool,
+// so the lean affine formula is always combining points whose dx is
+// well-defined.
+//
+// PARAMS:
+// params.x = N_in = 2 * S * T_curr (total input elements per plane)
+// params.y = T_curr
+// params.z = final_flag (non-zero selects the simple strided write)
+// LAYOUT (both input and output buffers):
+// 2 planes (P.x, P.y), PG=2 vec4 per element.
+// Plane p flat index for vec4 access: p * PG * N_buf + PG * e + {0,1}
+// where N_buf is the elements-per-plane for that buffer.
+// Input buffer's N_buf = 2 * S * T_curr (= N_in).
+// Output buffer's N_buf = S * T_curr (= N_in / 2).
+
+const S: u32 = {{ s }}u;
+const PG: u32 = 2u;
+
+@group(0) @binding(0) var inp: array>;
+@group(0) @binding(1) var unused: array>;
+@group(0) @binding(2) var outp: array>;
+@group(0) @binding(3) var params: vec4;
+
+fn load_in(plane: u32, t: u32, i: u32, T: u32, N_in: u32) -> BigInt {
+ let plane_base = plane * PG * N_in; // first vec4 of the requested plane
+ let base = plane_base + PG * (t + i * T); // element t + i*T, PG vec4 per element
+ let q0 = inp[base + 0u];
+ let q1 = inp[base + 1u];
+ var w: array;
+ w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w;
+ w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w;
+ return unpack256_to_limbs(w); // helper presumably injected by dec_unpack - confirm
+}
+
+fn store_out_tree(plane: u32, t: u32, k: u32, T_curr: u32, N_out: u32, val: ptr) {
+ // Tree write: out_pos(t, k) = (t >> 1) + (k + S * (t & 1)) * (T_curr >> 1)
+ // Lands in next-level strided read at index (t >> 1) with slot
+ // (k + S * (t & 1)). NOTE(review): T_curr >> 1 truncates, so an odd T_curr > 1 looks unsupported by this layout; confirm with the host.
+ let t_next = t >> 1u;
+ let slot_in_next = k + S * (t & 1u);
+ let T_next = T_curr >> 1u;
+ let plane_base = plane * PG * N_out;
+ let elem = t_next + slot_in_next * T_next;
+ let base = plane_base + PG * elem;
+ let w = pack_limbs_to_256(val);
+ outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]);
+ outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]);
+}
+
+fn store_out_simple(plane: u32, t: u32, k: u32, T_curr: u32, N_out: u32, val: ptr) {
+ // Final-level simple strided write: out_pos(t, k) = t + k * T_curr.
+ // Used when there is no next pair-tree level: main selects it when
+ // T_curr == 1 or when the host sets final_flag (params.z) non-zero.
+ let plane_base = plane * PG * N_out;
+ let elem = t + k * T_curr;
+ let base = plane_base + PG * elem;
+ let w = pack_limbs_to_256(val);
+ outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]);
+ outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]);
+}
+
+fn get_r() -> BigInt {
+ var r: BigInt;
+{{{ r_limbs }}}
+ return r;
+}
+
+@compute @workgroup_size({{ workgroup_size }})
+fn main(@builtin(global_invocation_id) gid: vec3) {
+ let N_in = params.x; // input elements per plane
+ let T_curr = params.y; // thread count and read stride at this level
+ let final_flag = params.z; // non-zero => use simple strided write
+ let N_out = N_in / 2u; // one output element per input pair
+
+ let t = gid.x;
+ if (t >= T_curr) { return; } // invocations beyond T_curr do nothing
+
+ var pref: array;
+ var acc: BigInt = get_r(); // placeholder value; replaced by dx at k == 0
+ for (var k: u32 = 0u; k < S; k = k + 1u) {
+ var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T_curr, N_in);
+ var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T_curr, N_in);
+ var dx: BigInt = fr_sub(&p_rx, &p_lx);
+ if (k == 0u) {
+ acc = dx;
+ } else {
+ acc = montgomery_product(&acc, &dx);
+ }
+ pref[k] = acc; // pref[k] = dx_0 * ... * dx_k
+ }
+
+ var inv: BigInt = fr_inv_by_a(acc); // single inversion of the full product; a zero dx anywhere in the chunk makes acc zero
+
+ for (var jj: u32 = 0u; jj < S; jj = jj + 1u) {
+ let k = S - 1u - jj; // walk k from S-1 down to 0
+
+ var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T_curr, N_in);
+ var p_ly: BigInt = load_in(1u, t, 2u * k + 0u, T_curr, N_in);
+ var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T_curr, N_in);
+ var p_ry: BigInt = load_in(1u, t, 2u * k + 1u, T_curr, N_in);
+
+ var inv_dx: BigInt;
+ if (k == 0u) {
+ inv_dx = inv;
+ } else {
+ var pp = pref[k - 1u];
+ inv_dx = montgomery_product(&inv, &pp); // inv is 1/pref[k] here, so this is 1/dx_k
+ }
+
+ var lambda: BigInt = fr_sub(&p_ry, &p_ly); // numerator y_r - y_l
+ lambda = montgomery_product(&lambda, &inv_dx); // lambda = (y_r - y_l) / (x_r - x_l)
+ var r_x: BigInt = montgomery_product(&lambda, &lambda);
+ r_x = fr_sub(&r_x, &p_lx);
+ r_x = fr_sub(&r_x, &p_rx); // r_x = lambda^2 - x_l - x_r
+ var r_y: BigInt = fr_sub(&p_lx, &r_x);
+ r_y = montgomery_product(&lambda, &r_y);
+ r_y = fr_sub(&r_y, &p_ly); // r_y = lambda * (x_l - r_x) - y_l
+
+ if (final_flag != 0u || T_curr == 1u) { // T_curr == 1 has no next level: the tree stride (T_curr >> 1) would be 0 and all S outputs would alias element 0
+ store_out_simple(0u, t, k, T_curr, N_out, &r_x);
+ store_out_simple(1u, t, k, T_curr, N_out, &r_y);
+ } else {
+ store_out_tree(0u, t, k, T_curr, N_out, &r_x);
+ store_out_tree(1u, t, k, T_curr, N_out, &r_y);
+ }
+
+ if (k > 0u) { // advance inv from 1/pref[k] to 1/pref[k-1]
+ var dx_back: BigInt = fr_sub(&p_rx, &p_lx);
+ inv = montgomery_product(&inv, &dx_back);
+ }
+ }
+
+ {{{ recompile }}}
+}
+`;
+
+export const ba_pair_disjoint_tree_prod = `{{> structs }}
+{{> bigint_funcs }}
+{{> montgomery_product_funcs }}
+{{> field_funcs }}
+{{> fr_pow_funcs }}
+{{> bigint_by_funcs }}
+{{> by_inverse_a_funcs }}
+
+{{{ dec_unpack }}}
+
+{{{ dec_pack }}}
+
+// Disjoint pair-sum kernel — prod variant for the v2 pair-tree
+// integration. Same disjoint pair-sum math as
+// ba_pair_disjoint_tree_bench (prefix-product, single fr_inv_by_a per
+// chunk + lean affine add); the per-level T (= num_chunks) is read
+// from the planner's totals[3] storage output and the dispatch happens
+// indirectly so only real chunks run. Always uses the final-mode
+// strided write (matches what ba_scatter_pairs_prod expects).
+//
+// LAYOUT: same as the bench variant. Combined-SoA input/output (2
+// planes, PG=2 vec4 per element, plane-major then element-major then
+// vec4 within an element).
+
+const S: u32 = {{ s }}u;
+const PG: u32 = 2u;
+
+@group(0) @binding(0) var inp: array>;
+@group(0) @binding(1) var unused: array>;
+@group(0) @binding(2) var outp: array>;
+@group(0) @binding(3) var totals: array;
+
+fn load_in(plane: u32, t: u32, i: u32, T: u32, N_in: u32) -> BigInt {
+ let plane_base = plane * PG * N_in; // first vec4 of the requested plane
+ let base = plane_base + PG * (t + i * T); // element t + i*T, PG vec4 per element
+ let q0 = inp[base + 0u];
+ let q1 = inp[base + 1u];
+ var w: array;
+ w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w;
+ w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w;
+ return unpack256_to_limbs(w); // helper presumably injected by dec_unpack - confirm
+}
+
+fn store_out_simple(plane: u32, t: u32, k: u32, T_curr: u32, N_out: u32, val: ptr) {
+ let plane_base = plane * PG * N_out; // first vec4 of the requested output plane
+ let elem = t + k * T_curr; // strided output slot: thread t, pair k
+ let base = plane_base + PG * elem;
+ let w = pack_limbs_to_256(val); // helper presumably injected by dec_pack - confirm
+ outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]);
+ outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]);
+}
+
+fn get_r() -> BigInt {
+ var r: BigInt;
+{{{ r_limbs }}}
+ return r;
+}
+
+@compute @workgroup_size({{ workgroup_size }})
+fn main(@builtin(global_invocation_id) gid: vec3) {
+ let T_curr = totals[3]; // thread count for this level (header: num_chunks written by the planner)
+ let N_in = 2u * S * T_curr; // input elements per plane, derived instead of passed in
+ let N_out = S * T_curr; // output elements per plane
+
+ let t = gid.x;
+ if (t >= T_curr) { return; } // invocations beyond T_curr do nothing
+
+ var pref: array;
+ var acc: BigInt = get_r(); // placeholder value; replaced by dx at k == 0
+ for (var k: u32 = 0u; k < S; k = k + 1u) {
+ var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T_curr, N_in);
+ var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T_curr, N_in);
+ var dx: BigInt = fr_sub(&p_rx, &p_lx);
+ if (k == 0u) {
+ acc = dx;
+ } else {
+ acc = montgomery_product(&acc, &dx);
+ }
+ pref[k] = acc; // pref[k] = dx_0 * ... * dx_k
+ }
+
+ var inv: BigInt = fr_inv_by_a(acc); // single inversion of the full product; a zero dx anywhere in the chunk makes acc zero
+
+ for (var jj: u32 = 0u; jj < S; jj = jj + 1u) {
+ let k = S - 1u - jj; // walk k from S-1 down to 0
+
+ var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T_curr, N_in);
+ var p_ly: BigInt = load_in(1u, t, 2u * k + 0u, T_curr, N_in);
+ var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T_curr, N_in);
+ var p_ry: BigInt = load_in(1u, t, 2u * k + 1u, T_curr, N_in);
+
+ var inv_dx: BigInt;
+ if (k == 0u) {
+ inv_dx = inv;
+ } else {
+ var pp = pref[k - 1u];
+ inv_dx = montgomery_product(&inv, &pp); // inv is 1/pref[k] here, so this is 1/dx_k
+ }
+
+ var lambda: BigInt = fr_sub(&p_ry, &p_ly); // numerator y_r - y_l
+ lambda = montgomery_product(&lambda, &inv_dx); // lambda = (y_r - y_l) / (x_r - x_l)
+ var r_x: BigInt = montgomery_product(&lambda, &lambda);
+ r_x = fr_sub(&r_x, &p_lx);
+ r_x = fr_sub(&r_x, &p_rx); // r_x = lambda^2 - x_l - x_r
+ var r_y: BigInt = fr_sub(&p_lx, &r_x);
+ r_y = montgomery_product(&lambda, &r_y);
+ r_y = fr_sub(&r_y, &p_ly); // r_y = lambda * (x_l - r_x) - y_l
+
+ store_out_simple(0u, t, k, T_curr, N_out, &r_x);
+ store_out_simple(1u, t, k, T_curr, N_out, &r_y);
+
+ if (k > 0u) { // advance inv from 1/pref[k] to 1/pref[k-1]
+ var dx_back: BigInt = fr_sub(&p_rx, &p_lx);
+ inv = montgomery_product(&inv, &dx_back);
+ }
+ }
+
+ {{{ recompile }}}
+}
+`;
+
+export const ba_planner_bench = `{{> structs }}
+
+// GPU-side bin-packing planner for the v3 MSM bucket-accumulate
+// pipeline. One thread per bucket; uses atomicAdd to reserve global
+// per-pair slots in chunk_plan / scatter_plan and per-carry slots in
+// carry_plan, then writes that bucket's entries.
+//
+// Inputs (per current level):
+// counts: array per-bucket active count
+// offsets: array per-bucket starting index in active_sums_old
+//
+// Outputs (filled in by this kernel for the current level):
+// chunk_plan: array 2 u32 per (chunk_id, slot) — pair operand indices
+// scatter_plan: array 1 u32 per (chunk_id, slot) — destination in active_sums_new
+// carry_plan: array 2 u32 per carry slot — (src in old, dst in new)
+// totals: array> [0]=total pairs, [1]=total carries, [2]=total new actives (atomicAdd accumulators; presumably zeroed by the host before dispatch - TODO confirm)
+// new_counts: array per-bucket new active count (for next level)
+// new_offsets: array per-bucket new offset in active_sums_new (for next level)
+//
+// Convention: discard slot = M_new - 1 (the highest index in
+// active_sums_new). Pad pair source indices = (pad_l_idx, pad_r_idx)
+// are chosen by the host. All non-real chunk_plan / scatter_plan slots
+// must be pre-padded to (pad_l_idx, pad_r_idx) and discard_idx by the
+// host before each planner dispatch.
+//
+// params.x = B (bucket count)
+// params.y = not read by this kernel (chunk size S is the compile-time constant below)
+// (pad_l_idx / pad_r_idx / discard_idx live in the pre-padded
+// arrays, not in params)
+
+const S: u32 = {{ s }}u;
+
+@group(0) @binding(0) var counts: array;
+@group(0) @binding(1) var offsets: array;
+@group(0) @binding(2) var chunk_plan: array;
+@group(0) @binding(3) var scatter_plan: array;
+@group(0) @binding(4) var carry_plan: array;
+@group(0) @binding(5) var totals: array>;
+@group(0) @binding(6) var new_counts: array;
+@group(0) @binding(7) var new_offsets: array;
+@group(0) @binding(8) var params: vec4;
+
+@compute @workgroup_size({{ workgroup_size }})
+fn main(@builtin(global_invocation_id) gid: vec3) {
+ let B = params.x;
+ let b = gid.x; // one invocation per bucket
+ if (b >= B) { return; }
+
+ let n = counts[b];
+ let pair_count = n / 2u; // full pairs in this bucket
+ let carry_flag = n & 1u; // 1 when an odd element is left over
+ let nc = pair_count + carry_flag; // actives this bucket contributes to the next level
+ new_counts[b] = nc;
+
+ // Atomic offset reservation. Each bucket gets a unique non-overlapping
+ // range in the global arrays. Atomic order is non-deterministic but
+ // that's fine: bucket b records its assigned offsets and uses them
+ // consistently for its own chunk_plan / scatter_plan / new_offsets
+ // writes. Different buckets land in different ranges by construction.
+ let my_pair_off = atomicAdd(&totals[0u], pair_count);
+ let my_carry_off = atomicAdd(&totals[1u], carry_flag);
+ let my_new_off = atomicAdd(&totals[2u], nc);
+ new_offsets[b] = my_new_off;
+
+ let bucket_base = offsets[b];
+
+ // Write this bucket's pair entries into chunk_plan / scatter_plan.
+ // Loop bounded by pair_count (variable per bucket; typically ~16
+ // for Poisson(λ=32)). The TAIL_CAP-style compile-time bound used
+ // by ba_tail_reduce isn't strictly needed here since this kernel
+ // doesn't do field arithmetic; the loop is plain integer writes.
+ // We still bound it by a compile-time constant for WGSL static
+ // analysis purposes. NOTE: pairs with j >= PAIR_CAP are silently skipped although their slots were reserved above.
+ let PAIR_CAP: u32 = {{ pair_cap }}u;
+ for (var j: u32 = 0u; j < PAIR_CAP; j = j + 1u) {
+ if (j >= pair_count) { break; }
+ let global_slot = my_pair_off + j;
+ let chunk_id = global_slot / S;
+ let slot_in_chunk = global_slot % S;
+ let cp_base = 2u * (chunk_id * S + slot_in_chunk); // equals 2u * global_slot
+ chunk_plan[cp_base + 0u] = bucket_base + 2u * j; // left operand of pair j
+ chunk_plan[cp_base + 1u] = bucket_base + 2u * j + 1u; // right operand of pair j
+ scatter_plan[chunk_id * S + slot_in_chunk] = my_new_off + j; // destination of the pair sum
+ }
+
+ if (carry_flag != 0u) {
+ let cs = my_carry_off;
+ carry_plan[2u * cs + 0u] = bucket_base + n - 1u; // source: last (unpaired) element of the bucket
+ carry_plan[2u * cs + 1u] = my_new_off + pair_count; // destination: right after this bucket's pair sums
+ }
+
+ {{{ recompile }}}
+}
+`;
+
+export const ba_planner_v2_bench = `{{> structs }}
+
+// Optimal single-kernel GPU bin-packing planner for the MSM
+// bucket-accumulate pair-tree.
+//
+// One workgroup of TPB threads processes B buckets. Each thread
+// handles PER_THREAD = B / TPB buckets via a contiguous slice
+// [tid * PER_THREAD, (tid+1) * PER_THREAD).
+//
+// Phase A — Per-thread local scan
+// For each of its PER_THREAD buckets, compute (pair_count, carry_flag,
+// new_count). Accumulate per-thread totals (sum across the thread's
+// slice). Keep the per-bucket triples in registers; we will re-scan
+// them in Phase C.
+//
+// Phase B — Workgroup-wide Hillis-Steele scan (3 in parallel)
+// Scan the per-thread totals for pair, carry, new across the TPB
+// threads in shared memory. Result: each thread gets the global
+// prefix sum at the START of its slice (= base offset for its first
+// bucket).
+//
+// Phase C — Per-thread scatter
+// For each bucket in the thread's slice (in order), use the running
+// thread-local offset to compute global pair_offset_b and write the
+// pair_count[b] chunk_plan entries plus the (optional) carry_plan
+// entry. Update local running offsets. Write new_counts[b] and
+// new_offsets[b] for the next level.
+//
+// Phase D — One thread writes totals.
+// totals[0] = total_pairs, totals[1] = total_carries,
+// totals[2] = total_new_actives.
+//
+// Single dispatch. No atomics. No host sync. Scales to B = TPB *
+// PER_THREAD (e.g. 256 * 32 = 8192) within one workgroup. Larger B
+// requires multi-workgroup scan + global combine (out of scope here).
+//
+// Compile-time constants:
+// TPB : workgroup size (e.g. 256)
+// PER_THREAD : buckets per thread (e.g. 16 for B=4096, 32 for B=8192)
+// PAIR_CAP : bound on per-bucket pair count (Poisson(λ=32) tail
+// is ~30; choose 64 for safety)
+// S : chunk size in pairs (e.g. 16)
+
+const TPB: u32 = {{ workgroup_size }}u;
+const PER_THREAD: u32 = {{ per_thread }}u;
+const PAIR_CAP: u32 = {{ pair_cap }}u;
+const S: u32 = {{ s }}u;
+
+@group(0) @binding(0) var counts: array;
+@group(0) @binding(1) var offsets: array;
+@group(0) @binding(2) var chunk_plan: array;
+@group(0) @binding(3) var scatter_plan: array;
+@group(0) @binding(4) var carry_plan: array;
+@group(0) @binding(5) var new_counts: array;
+@group(0) @binding(6) var new_offsets: array;
+@group(0) @binding(7) var totals: array;
+@group(0) @binding(8) var params: vec4;
+// params.x = B
+
+// Workgroup-shared running prefixes for the 3 scans.
+var pair_scan: array;
+var carry_scan: array;
+var new_scan: array;
+
+@compute @workgroup_size({{ workgroup_size }})
+fn main(@builtin(local_invocation_id) lid: vec3) {
+ let tid = lid.x; // only the local id is used, matching the single-workgroup design in the header
+ let B = params.x;
+
+ // Phase A: per-thread local read + accumulate.
+ // Keep PER_THREAD bucket triples in registers (small array).
+ var local_pc: array;
+ var local_cf: array;
+ var local_nc: array;
+ var sum_p: u32 = 0u;
+ var sum_c: u32 = 0u;
+ var sum_n: u32 = 0u;
+ for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) {
+ let b = tid * PER_THREAD + k;
+ var pc: u32 = 0u;
+ var cf: u32 = 0u;
+ var nc: u32 = 0u;
+ if (b < B) { // slots past B keep the zero triple so they add nothing to the sums
+ let n = counts[b];
+ pc = n / 2u;
+ cf = n & 1u;
+ nc = pc + cf;
+ }
+ local_pc[k] = pc;
+ local_cf[k] = cf;
+ local_nc[k] = nc;
+ sum_p += pc;
+ sum_c += cf;
+ sum_n += nc;
+ }
+
+ // Phase B: workgroup-wide Hillis-Steele inclusive scan over per-
+ // thread totals (3 scans interleaved).
+ pair_scan[tid] = sum_p;
+ carry_scan[tid] = sum_c;
+ new_scan[tid] = sum_n;
+ workgroupBarrier();
+ for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) {
+ var add_p: u32 = 0u;
+ var add_c: u32 = 0u;
+ var add_n: u32 = 0u;
+ if (tid >= stride) {
+ add_p = pair_scan[tid - stride];
+ add_c = carry_scan[tid - stride];
+ add_n = new_scan[tid - stride];
+ }
+ workgroupBarrier(); // all reads of this round finish before any slot is overwritten
+ if (tid >= stride) {
+ pair_scan[tid] = pair_scan[tid] + add_p;
+ carry_scan[tid] = carry_scan[tid] + add_c;
+ new_scan[tid] = new_scan[tid] + add_n;
+ }
+ workgroupBarrier(); // all writes of this round finish before the next round reads
+ }
+ // pair_scan[tid] is now inclusive prefix. Exclusive base = inclusive - own_sum.
+ var local_pair_off: u32 = pair_scan[tid] - sum_p;
+ var local_carry_off: u32 = carry_scan[tid] - sum_c;
+ var local_new_off: u32 = new_scan[tid] - sum_n;
+
+ // Phase D: the last thread (tid == TPB - 1) writes totals; its inclusive scan value is the sum over all threads.
+ if (tid == TPB - 1u) {
+ totals[0] = pair_scan[tid];
+ totals[1] = carry_scan[tid];
+ totals[2] = new_scan[tid];
+ }
+
+ // Phase C: per-thread scatter.
+ for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) {
+ let b = tid * PER_THREAD + k;
+ if (b >= B) { break; }
+
+ let pc = local_pc[k];
+ let cf = local_cf[k];
+ let nc = local_nc[k];
+ new_counts[b] = nc;
+ new_offsets[b] = local_new_off;
+
+ let bucket_base = offsets[b];
+
+ // Pair entries: bounded loop, break at pc.
+ for (var j: u32 = 0u; j < PAIR_CAP; j = j + 1u) {
+ if (j >= pc) { break; }
+ let global_slot = local_pair_off + j;
+ let chunk_id = global_slot / S;
+ let slot_in_chunk = global_slot % S;
+ let cp_base = 2u * (chunk_id * S + slot_in_chunk); // equals 2u * global_slot
+ chunk_plan[cp_base + 0u] = bucket_base + 2u * j; // left operand of pair j
+ chunk_plan[cp_base + 1u] = bucket_base + 2u * j + 1u; // right operand of pair j
+ scatter_plan[chunk_id * S + slot_in_chunk] = local_new_off + j; // destination of the pair sum
+ }
+
+ // Carry entry (if odd count).
+ if (cf != 0u) {
+ let cs = local_carry_off;
+ carry_plan[2u * cs + 0u] = bucket_base + counts[b] - 1u; // source: last (unpaired) element of the bucket
+ carry_plan[2u * cs + 1u] = local_new_off + pc; // destination: right after this bucket's pair sums
+ }
+
+ local_pair_off += pc; // running offsets advance by the full counts, even if pc exceeded PAIR_CAP above
+ local_carry_off += cf;
+ local_new_off += nc;
+ }
+
+ {{{ recompile }}}
+}
+`;
+
+export const ba_planner_v2_prod = `{{> structs }}
+
+// Production GPU bin-packing planner for the v2 pair-tree integration.
+//
+// Same algorithm as ba_planner_v2_bench (one workgroup of TPB threads,
+// per-thread local scan, workgroup-wide Hillis-Steele scan over the
+// three running sums, per-thread scatter) but extends the totals
+// output with the indirect-dispatch counts the production marshal /
+// disjoint / scatter / carry kernels need:
+//
+// totals[0] = total_pairs
+// totals[1] = total_carries
+// totals[2] = total_new
+// totals[3] = num_chunks = max(1, (total_pairs + S - 1) / S)
+// totals[4] = marshal/disjoint/scatter dispatch X (= ceil(num_chunks / WGI))
+// totals[5] = 1
+// totals[6] = 1
+// totals[7] = carry dispatch X (= ceil(total_carries / WGI))
+// totals[8] = 1
+// totals[9] = 1
+//
+// The four prod-variant downstream kernels (ba_marshal_pairs_prod,
+// ba_pair_disjoint_tree_prod, ba_scatter_pairs_prod, ba_carry_copy_prod)
+// read num_chunks and total_carries from this same totals storage
+// buffer so a single planner dispatch fully drives the level's runtime
+// shape with zero wasted-pad-chunk compute. The host orchestrator
+// reuses the totals buffer as the indirect-dispatch source via
+// dispatchWorkgroupsIndirect(totals, 16) for marshal/disjoint/scatter
+// (totals u32 indices 4..6) and dispatchWorkgroupsIndirect(totals, 28)
+// for carry (totals u32 indices 7..9).
+//
+// Compile-time constants:
+// TPB : workgroup size (e.g. 256)
+// PER_THREAD : buckets per thread
+// PAIR_CAP : per-bucket pair-count bound
+// S : chunk size in pairs
+// WGI : downstream kernel workgroup size — must match the
+// workgroup_size of ba_marshal_pairs_prod /
+// ba_pair_disjoint_tree_prod / ba_scatter_pairs_prod /
+// ba_carry_copy_prod.
+
+const TPB: u32 = {{ workgroup_size }}u;
+const PER_THREAD: u32 = {{ per_thread }}u;
+const PAIR_CAP: u32 = {{ pair_cap }}u;
+const S: u32 = {{ s }}u;
+const WGI: u32 = {{ wgi }}u;
+
+@group(0) @binding(0) var counts: array;
+@group(0) @binding(1) var offsets: array;
+@group(0) @binding(2) var chunk_plan: array;
+@group(0) @binding(3) var scatter_plan: array;
+@group(0) @binding(4) var carry_plan: array;
+@group(0) @binding(5) var new_counts: array;
+@group(0) @binding(6) var new_offsets: array;
+@group(0) @binding(7) var totals: array;
+@group(0) @binding(8) var