diff --git a/barretenberg/ts/dev/msm-webgpu/bench-ba-pair-disjoint.html b/barretenberg/ts/dev/msm-webgpu/bench-ba-pair-disjoint.html new file mode 100644 index 000000000000..95dc4bee2cdb --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-ba-pair-disjoint.html @@ -0,0 +1,37 @@ + + + + + Disjoint pair-sum batch-affine bench (WebGPU) + + + +

Disjoint pair-sum batch-affine bench (WebGPU)

+

Query params: ?reps=R&pairs=N&wgi=W&disp=D&s=A,B,C

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-ba-pair-disjoint.ts b/barretenberg/ts/dev/msm-webgpu/bench-ba-pair-disjoint.ts new file mode 100644 index 000000000000..500fff0eba60 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-ba-pair-disjoint.ts @@ -0,0 +1,435 @@ +/// +// Standalone WebGPU bench for the disjoint pair-sum kernel +// (ba_pair_disjoint_bench): each thread reduces 2*S input points to +// S disjoint pair sums R_k = P_{2k} + P_{2k+1}, using one batched +// fr_inv_by_a per chunk of S. Same DISP=8 dispatch amortisation as +// bench-ba-rev-packed-carry to keep the measurement methodology +// apples-to-apples. +// +// Input: random Montgomery field elems packed into SoA layout with +// 2 planes (P.x, P.y), 2*PAIRS elements per plane. Pairs are arranged +// at strided positions e = t + i*T for i in 0..2S so the kernel's +// coalesced reads work. Adjacent pair members are guaranteed distinct +// x to avoid the lean-add div-by-zero. +// +// Output: 2 planes (R.x, R.y), PAIRS elements per plane. Reports +// ns/useful-pair-sum (every kernel output is a usable disjoint pair +// sum, vs the chain kernel where only S/2 of S outputs are usable). 

import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js';
import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js';
import { makeResultsClient } from './results_post.js';

// Number of vec4 groups per packed field element (8 packed u32 words / 4).
const PG = 2;
const DEFAULT_PAIRS = 1 << 17; // 131072 output pair sums per dispatch
const DEFAULT_WGI = 64;
const DEFAULT_DISP = 8;
const DEFAULT_S_SWEEP: readonly number[] = [16, 32, 64];

// Effective tunables; they start at the defaults above and are reassigned
// elsewhere in this file.
let PAIRS = DEFAULT_PAIRS;
let WGI = DEFAULT_WGI;
let DISP = DEFAULT_DISP;
let S_SWEEP: readonly number[] = DEFAULT_S_SWEEP;

/**
 * Deterministic 32-bit linear congruential generator
 * (state' = state * 1664525 + 1013904223 mod 2^32).
 *
 * @param seed Initial state; coerced to u32, and a zero seed is replaced by 1.
 * @returns A function yielding the next u32 state on every call.
 */
function makeRng(seed: number): () => number {
  let state = (seed >>> 0) || 1;
  return () => {
    state = (Math.imul(state, 1664525) + 1013904223) >>> 0;
    return state;
  };
}

/**
 * Rejection-samples an integer v with 0 < v < p from `rng`.
 *
 * Each candidate is assembled big-endian from ceil(bitlen(p) / 8) bytes and
 * masked down to bitlen(p) bits before the range check.
 *
 * Bytes are taken from the HIGH 8 bits of each rng output. `makeRng` is an LCG
 * with modulus 2^32, whose low k bits repeat with period 2^k: using
 * `rng() & 0xff` would give a byte stream of period 256, i.e. only
 * 256 / byteLen distinct candidates (8 for a 254-bit modulus).
 *
 * @param p Exclusive upper bound; must be greater than 1.
 * @param rng Source of u32 values.
 * @returns A value in the open interval (0, p).
 * @throws RangeError if p <= 1, since no value satisfies 0 < v < p.
 */
function randomBelow(p: bigint, rng: () => number): bigint {
  if (p <= 1n) throw new RangeError(`randomBelow: modulus must be > 1, got ${p}`);
  const bitlen = p.toString(2).length;
  const byteLen = Math.ceil(bitlen / 8);
  // Loop-invariant: strips the excess top bits so the acceptance rate is > 1/2.
  const mask = (1n << BigInt(bitlen)) - 1n;
  for (;;) {
    let v = 0n;
    for (let i = 0; i < byteLen; i++) v = (v << 8n) | BigInt(rng() >>> 24);
    v &= mask;
    if (v > 0n && v < p) return v;
  }
}

/**
 * Splits `v` into eight u32 words, least-significant word first.
 * Bits of `v` above bit 255 are discarded.
 */
function bigintToPackedU32x8(v: bigint): Uint32Array {
  const w = new Uint32Array(8);
  let x = v;
  for (let i = 0; i < 8; i++) {
    w[i] = Number(x & 0xffffffffn);
    x >>= 32n;
  }
  return w;
}

/**
 * Returns the middle element of the ascending-sorted samples (the upper of the
 * two middle elements when the length is even), or NaN for an empty array.
 * The input array is not mutated.
 */
function median(xs: number[]): number {
  if (xs.length === 0) return NaN;
  const s = xs.slice().sort((a, b) => a - b);
  return s[Math.floor(s.length / 2)];
}

// SoA-packed input buffer with 2 planes (P.x, P.y), each PG*(2*PAIRS)
// vec4. Plane p at element idx e: vec4 indices (p*PG + v)*N_in + e for
// v in 0..PG, where N_in = 2*PAIRS. Adjacent pairs (2k, 2k+1) have
// distinct x.
+function buildPackedPairsSoA(pairs: number, R: bigint, p: bigint, rng: () => number): Uint32Array { + const N_in = 2 * pairs; + const buf = new Uint32Array(2 * PG * N_in * 4); + for (let k = 0; k < pairs; k++) { + let lx: bigint; + let rx: bigint; + do { + lx = (randomBelow(p, rng) * R) % p; + rx = (randomBelow(p, rng) * R) % p; + } while (lx === rx); + const ly = (randomBelow(p, rng) * R) % p; + const ry = (randomBelow(p, rng) * R) % p; + const writeElem = (planeIdx: number, e: number, val: bigint) => { + const words = bigintToPackedU32x8(val); + for (let v = 0; v < PG; v++) { + const base = ((planeIdx * PG + v) * N_in + e) * 4; + buf[base + 0] = words[4 * v + 0]; + buf[base + 1] = words[4 * v + 1]; + buf[base + 2] = words[4 * v + 2]; + buf[base + 3] = words[4 * v + 3]; + } + }; + writeElem(0, 2 * k + 0, lx); + writeElem(1, 2 * k + 0, ly); + writeElem(0, 2 * k + 1, rx); + writeElem(1, 2 * k + 1, ry); + } + return buf; +} + +interface PerSizeResult { + s: number; + wgi: number; + T: number; + num_wgs: number; + pairs: number; + disp: number; + total_ops: number; + median_ms: number; + min_ms: number; + max_ms: number; + ns_per_op: number; + samples_ms: number[]; + sanity_ok: boolean; +} + +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: { reps: number; pairs: number; wgi: number; disp: number; s_sweep: readonly number[] } | null; + results: PerSizeResult[]; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { + state: 'boot', + params: null, + results: [], + error: null, + log: [], +}; +(window as unknown as { __bench: BenchState }).__bench = benchState; + +const resultsClient = makeResultsClient({ page: 'bench-ba-pair-disjoint' }); +(window as unknown as { __runId: string }).__runId = resultsClient.runId; + +async function postFinal(): Promise { + await resultsClient.postResults({ + state: benchState.state, + params: benchState.params, + results: benchState.results, + error: benchState.error, + log: 
benchState.log, + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); +} + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-ba-pair-disjoint] ${msg}`); +} + +async function compile( + device: GPUDevice, + code: string, + cacheKey: string, +): Promise<{ pipeline: GPUComputePipeline; layout: GPUBindGroupLayout }> { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + const errLines: string[] = []; + for (const msg of info.messages) { + const line = `[shader ${cacheKey}] ${msg.type}: ${msg.message} (line ${msg.lineNum}, col ${msg.linePos})`; + if (msg.type === 'error') { + console.error(line); + log('err', line); + errLines.push(line); + hasError = true; + } else { + console.warn(line); + log('warn', line); + } + } + if (hasError) throw new Error(`WGSL compile failed for ${cacheKey}: ${errLines.join(' | ')}`); + const layout = device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); + const pipeline = await device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); + return { pipeline, layout }; +} + +async function readNonZero(device: GPUDevice, buf: 
GPUBuffer, u32Count: number): Promise { + const bytes = u32Count * 4; + const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }); + const enc = device.createCommandEncoder(); + enc.copyBufferToBuffer(buf, 0, staging, 0, bytes); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + await staging.mapAsync(GPUMapMode.READ); + const u32 = new Uint32Array(staging.getMappedRange().slice(0)); + staging.unmap(); + staging.destroy(); + for (let i = 0; i < u32.length; i++) if (u32[i] !== 0) return true; + return false; +} + +async function timeDispatch( + device: GPUDevice, + pipeline: GPUComputePipeline, + bind: GPUBindGroup, + numWgs: number, + reps: number, + passes: number, +): Promise { + { + const enc = device.createCommandEncoder(); + for (let pIdx = 0; pIdx < passes; pIdx++) { + const pass = enc.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroups(numWgs, 1, 1); + pass.end(); + } + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + } + const samples: number[] = []; + for (let r = 0; r < reps; r++) { + const enc = device.createCommandEncoder(); + for (let pIdx = 0; pIdx < passes; pIdx++) { + const pass = enc.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroups(numWgs, 1, 1); + pass.end(); + } + const t0 = performance.now(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + samples.push(performance.now() - t0); + } + return samples; +} + +async function runOne( + device: GPUDevice, + sm: ShaderManager, + s: number, + reps: number, + R: bigint, + p: bigint, + seed: number, +): Promise { + if (PAIRS % s !== 0) throw new Error(`PAIRS=${PAIRS} must be a multiple of S=${s}`); + const T = PAIRS / s; + const numWgs = Math.ceil(T / WGI); + log('info', `=== S=${s}: PAIRS=${PAIRS} T=${T} WGI=${WGI} numWgs=${numWgs} 
DISP=${DISP}`); + + const code = sm.gen_ba_pair_disjoint_bench_shader(WGI, s); + log('info', `compiling shader (${code.length} chars)`); + (window as unknown as Record)[`__shader_s${s}`] = code; + const cacheKey = `bench-ba-pair-disjoint-W${WGI}-S${s}`; + const { pipeline, layout } = await compile(device, code, cacheKey); + + const rng = makeRng(seed); + const inU32 = buildPackedPairsSoA(PAIRS, R, p, rng); + + const inBuf = device.createBuffer({ + size: inU32.byteLength, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST, + }); + device.queue.writeBuffer(inBuf, 0, inU32); + const dummy = device.createBuffer({ size: 16, usage: GPUBufferUsage.STORAGE }); + const outBytes = 2 * PG * PAIRS * 4 * 4; + const outBuf = device.createBuffer({ + size: outBytes, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC, + }); + const paramsBuf = device.createBuffer({ + size: 16, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + device.queue.writeBuffer(paramsBuf, 0, new Uint32Array([2 * PAIRS, T, 0, 0])); + + const bind = device.createBindGroup({ + layout, + entries: [ + { binding: 0, resource: { buffer: inBuf } }, + { binding: 1, resource: { buffer: dummy } }, + { binding: 2, resource: { buffer: outBuf } }, + { binding: 3, resource: { buffer: paramsBuf } }, + ], + }); + + const samples = await timeDispatch(device, pipeline, bind, numWgs, reps, DISP); + const sanityOk = await readNonZero(device, outBuf, 8); + const med = median(samples); + const totalOps = PAIRS * DISP; + const nsPerOp = (med * 1e6) / totalOps; + + log( + sanityOk ? 'ok' : 'err', + `S=${s}: median=${med.toFixed(3)}ms min=${Math.min(...samples).toFixed(3)}ms max=${Math.max(...samples).toFixed(3)}ms ns/op=${nsPerOp.toFixed(2)} sanity=${sanityOk ? 
'OK' : 'FAIL'}`, + ); + + inBuf.destroy(); + dummy.destroy(); + outBuf.destroy(); + paramsBuf.destroy(); + + return { + s, + wgi: WGI, + T, + num_wgs: numWgs, + pairs: PAIRS, + disp: DISP, + total_ops: totalOps, + median_ms: med, + min_ms: Math.min(...samples), + max_ms: Math.max(...samples), + ns_per_op: nsPerOp, + samples_ms: samples, + sanity_ok: sanityOk, + }; +} + +function parseParams() { + const qp = new URLSearchParams(window.location.search); + const reps = parseInt(qp.get('reps') ?? '5', 10); + if (!Number.isFinite(reps) || reps <= 0 || reps > 50) throw new Error(`?reps must be in (0, 50]`); + const pairsStr = qp.get('pairs'); + if (pairsStr !== null) { + const v = parseInt(pairsStr, 10); + if (!Number.isFinite(v) || v <= 0 || v > (1 << 20)) throw new Error(`?pairs must be in (0, 2^20]`); + PAIRS = v; + } + const wgiStr = qp.get('wgi'); + if (wgiStr !== null) { + const v = parseInt(wgiStr, 10); + if (!Number.isFinite(v) || v <= 0 || v > 1024) throw new Error(`?wgi must be in (0, 1024]`); + WGI = v; + } + const dispStr = qp.get('disp'); + if (dispStr !== null) { + const v = parseInt(dispStr, 10); + if (!Number.isFinite(v) || v <= 0 || v > 64) throw new Error(`?disp must be in (0, 64]`); + DISP = v; + } + const sStr = qp.get('s'); + if (sStr !== null) { + const list = sStr.split(',').map(v => parseInt(v, 10)); + for (const v of list) { + if (!Number.isFinite(v) || v <= 0 || v > 256) throw new Error(`?s entries must be in (0, 256]`); + } + S_SWEEP = list; + } + for (const v of S_SWEEP) { + if (PAIRS % v !== 0) throw new Error(`S=${v} does not divide PAIRS=${PAIRS}`); + } + return { reps, pairs: PAIRS, wgi: WGI, disp: DISP, s_sweep: S_SWEEP }; +} + +async function main() { + try { + if (!('gpu' in navigator)) throw new Error('navigator.gpu missing — WebGPU not available'); + const params = parseParams(); + benchState.params = params; + log( + 'info', + `params: reps=${params.reps} pairs=${params.pairs} wgi=${params.wgi} disp=${params.disp} 
s=[${params.s_sweep.join(',')}]`, + ); + + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + + const p = BN254_BASE_FIELD; + const miscParams = compute_misc_params(p, 13); + const R = miscParams.r; + + const sm = new ShaderManager(4, PAIRS, BN254_CURVE_CONFIG, false); + + let seed = 0xd1d1; + for (const s of S_SWEEP) { + try { + const r = await runOne(device, sm, s, params.reps, R, p, seed); + benchState.results.push(r); + resultsClient.postProgress({ + kind: 'batch_done', + s, + median_ms: r.median_ms, + ns_per_op: r.ns_per_op, + sanity_ok: r.sanity_ok, + }); + seed += 0x10; + } catch (e) { + const msg = e instanceof Error ? e.message : String(e); + log('err', `S=${s} failed: ${msg} — STOPPING`); + benchState.state = 'error'; + benchState.error = msg; + return; + } + } + + benchState.state = 'done'; + log('ok', 'all sizes done'); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main() + .catch(e => { + const msg = e instanceof Error ? e.message : String(e); + log('err', `unhandled: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + }) + .finally(() => { + postFinal().catch(() => {}); + }); diff --git a/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.html b/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.html new file mode 100644 index 000000000000..6676d4772bd2 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.html @@ -0,0 +1,37 @@ + + + + + ba_rev_packed_carry standalone batch-affine bench (WebGPU) + + + +

ba_rev_packed_carry standalone batch-affine bench (WebGPU)

+

Query params: ?reps=R&pairs=N&wgi=W&disp=D&s=A,B,C

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.ts b/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.ts new file mode 100644 index 000000000000..2e55c2c9b7e4 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-ba-rev-packed-carry.ts @@ -0,0 +1,438 @@ +/// +// Standalone WebGPU bench for the canonical ba_rev_packed_carry kernel +// (recovered from commit eab3a3e). SoA-packed 8x u32 storage across 4 +// input planes (A.x, A.y, P.x, P.y), strided per-thread access +// e = t + i*T, single fr_inv_by_a per S-chunk, lean affine apply with +// resident-accumulator load-carry. Each timed sample submits DISP back- +// to-back dispatches in one command encoder to amortise submit+drain. +// +// Math: bucket-accumulate streaming chain (R_i = A_i + P_i where A_0 is +// the seed, A_{i+1} := P_i). With random P.x != A.x inputs, every dx is +// nonzero so the batched inverse is well-defined. Sanity check reads +// the first packed elem of R.x and asserts non-zero. 

import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js';
import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js';
import { makeResultsClient } from './results_post.js';

const DEFAULT_PAIRS = 1 << 17; // 131072
const DEFAULT_WGI = 64;
const DEFAULT_DISP = 8;
const DEFAULT_S_SWEEP: readonly number[] = [16, 32, 64];
const PG = 2; // 8 packed u32 / 4 = 2 vec4 groups per element

// Effective tunables; they start at the defaults above and are reassigned
// elsewhere in this file.
let PAIRS = DEFAULT_PAIRS;
let WGI = DEFAULT_WGI;
let DISP = DEFAULT_DISP;
let S_SWEEP: readonly number[] = DEFAULT_S_SWEEP;

/**
 * Deterministic 32-bit linear congruential generator
 * (state' = state * 1664525 + 1013904223 mod 2^32).
 *
 * @param seed Initial state; coerced to u32, and a zero seed is replaced by 1.
 * @returns A function yielding the next u32 state on every call.
 */
function makeRng(seed: number): () => number {
  let state = (seed >>> 0) || 1;
  return () => {
    state = (Math.imul(state, 1664525) + 1013904223) >>> 0;
    return state;
  };
}

/**
 * Rejection-samples an integer v with 0 < v < p from `rng`.
 *
 * Each candidate is assembled big-endian from ceil(bitlen(p) / 8) bytes and
 * masked down to bitlen(p) bits before the range check.
 *
 * Bytes are taken from the HIGH 8 bits of each rng output. `makeRng` is an LCG
 * with modulus 2^32, whose low k bits repeat with period 2^k: using
 * `rng() & 0xff` would give a byte stream of period 256, i.e. only
 * 256 / byteLen distinct candidates (8 for a 254-bit modulus).
 *
 * @param p Exclusive upper bound; must be greater than 1.
 * @param rng Source of u32 values.
 * @returns A value in the open interval (0, p).
 * @throws RangeError if p <= 1, since no value satisfies 0 < v < p.
 */
function randomBelow(p: bigint, rng: () => number): bigint {
  if (p <= 1n) throw new RangeError(`randomBelow: modulus must be > 1, got ${p}`);
  const bitlen = p.toString(2).length;
  const byteLen = Math.ceil(bitlen / 8);
  // Loop-invariant: strips the excess top bits so the acceptance rate is > 1/2.
  const mask = (1n << BigInt(bitlen)) - 1n;
  for (;;) {
    let v = 0n;
    for (let i = 0; i < byteLen; i++) v = (v << 8n) | BigInt(rng() >>> 24);
    v &= mask;
    if (v > 0n && v < p) return v;
  }
}

/**
 * Splits `v` into eight u32 words, least-significant word first.
 * Bits of `v` above bit 255 are discarded.
 */
function bigintToPackedU32x8(v: bigint): Uint32Array {
  const w = new Uint32Array(8);
  let x = v;
  for (let i = 0; i < 8; i++) {
    w[i] = Number(x & 0xffffffffn);
    x >>= 32n;
  }
  return w;
}

// SoA layout: 4 planes (A.x, A.y, P.x, P.y), each plane = PG * N vec4.
// For plane c and pair e, vec4 indices = (c * PG + v) * N + e for v in
// [0, PG). The flat Uint32Array index is that times 4.
+function packAffineSoAPacked(pairs: number, R: bigint, p: bigint, rng: () => number): Uint32Array { + const N = pairs; + const buf = new Uint32Array(4 * PG * N * 4); + for (let e = 0; e < N; e++) { + let pxM: bigint; + let qxM: bigint; + do { + pxM = (randomBelow(p, rng) * R) % p; + qxM = (randomBelow(p, rng) * R) % p; + } while (pxM === qxM); + const coords = [pxM, (randomBelow(p, rng) * R) % p, qxM, (randomBelow(p, rng) * R) % p]; + for (let c = 0; c < 4; c++) { + const words = bigintToPackedU32x8(coords[c]); + for (let v = 0; v < PG; v++) { + const base = ((c * PG + v) * N + e) * 4; + buf[base + 0] = words[4 * v + 0]; + buf[base + 1] = words[4 * v + 1]; + buf[base + 2] = words[4 * v + 2]; + buf[base + 3] = words[4 * v + 3]; + } + } + } + return buf; +} + +function median(xs: number[]): number { + if (xs.length === 0) return NaN; + const s = xs.slice().sort((a, b) => a - b); + return s[Math.floor(s.length / 2)]; +} + +interface PerSizeResult { + s: number; + wgi: number; + T: number; + num_wgs: number; + pairs: number; + disp: number; + total_ops: number; + median_ms: number; + min_ms: number; + max_ms: number; + ns_per_op: number; + samples_ms: number[]; + sanity_ok: boolean; +} + +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: { reps: number; pairs: number; wgi: number; disp: number; s_sweep: readonly number[] } | null; + results: PerSizeResult[]; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { + state: 'boot', + params: null, + results: [], + error: null, + log: [], +}; +(window as unknown as { __bench: BenchState }).__bench = benchState; + +const resultsClient = makeResultsClient({ page: 'bench-ba-rev-packed-carry' }); +(window as unknown as { __runId: string }).__runId = resultsClient.runId; + +async function postFinal(): Promise { + await resultsClient.postResults({ + state: benchState.state, + params: benchState.params, + results: benchState.results, + error: benchState.error, + log: 
benchState.log, + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); +} + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-ba-rev-packed-carry] ${msg}`); +} + +async function createPipeline( + device: GPUDevice, + code: string, + cacheKey: string, +): Promise<{ pipeline: GPUComputePipeline; layout: GPUBindGroupLayout }> { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + const errLines: string[] = []; + for (const msg of info.messages) { + const line = `[shader ${cacheKey}] ${msg.type}: ${msg.message} (line ${msg.lineNum}, col ${msg.linePos})`; + if (msg.type === 'error') { + console.error(line); + log('err', line); + errLines.push(line); + hasError = true; + } else { + console.warn(line); + log('warn', line); + } + } + if (hasError) { + throw new Error(`WGSL compile failed for ${cacheKey}: ${errLines.join(' | ')}`); + } + const layout = device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); + const pipeline = await device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); + return { pipeline, layout }; +} + +async function readNonZero(device: 
GPUDevice, out: GPUBuffer, u32Count: number): Promise { + const bytes = u32Count * 4; + const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }); + const enc = device.createCommandEncoder(); + enc.copyBufferToBuffer(out, 0, staging, 0, bytes); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + await staging.mapAsync(GPUMapMode.READ); + const u32 = new Uint32Array(staging.getMappedRange().slice(0)); + staging.unmap(); + staging.destroy(); + for (let i = 0; i < u32.length; i++) { + if (u32[i] !== 0) return true; + } + return false; +} + +async function timeDispatch( + device: GPUDevice, + pipeline: GPUComputePipeline, + bind: GPUBindGroup, + numWgs: number, + reps: number, + passes: number, +): Promise { + // warmup + { + const enc = device.createCommandEncoder(); + for (let pIdx = 0; pIdx < passes; pIdx++) { + const pass = enc.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroups(numWgs, 1, 1); + pass.end(); + } + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + } + const samples: number[] = []; + for (let r = 0; r < reps; r++) { + const enc = device.createCommandEncoder(); + for (let pIdx = 0; pIdx < passes; pIdx++) { + const pass = enc.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroups(numWgs, 1, 1); + pass.end(); + } + const t0 = performance.now(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + samples.push(performance.now() - t0); + } + return samples; +} + +async function runOne( + device: GPUDevice, + sm: ShaderManager, + s: number, + reps: number, + R: bigint, + p: bigint, + seed: number, +): Promise { + if (PAIRS % s !== 0) { + throw new Error(`PAIRS=${PAIRS} must be a multiple of S=${s}`); + } + const T = PAIRS / s; + const numWgs = Math.ceil(T / WGI); + log('info', `=== S=${s}: PAIRS=${PAIRS} 
T=${T} WGI=${WGI} numWgs=${numWgs} DISP=${DISP}`); + + const code = sm.gen_ba_rev_packed_carry_bench_shader(WGI, s); + const cacheKey = `bench-ba-rev-packed-carry-W${WGI}-S${s}`; + log('info', `compiling shader (${code.length} chars)`); + (window as unknown as Record)[`__shader_s${s}`] = code; + const { pipeline, layout } = await createPipeline(device, code, cacheKey); + + const rng = makeRng(seed); + const inU32 = packAffineSoAPacked(PAIRS, R, p, rng); + const inBuf = device.createBuffer({ + size: inU32.byteLength, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST, + }); + device.queue.writeBuffer(inBuf, 0, inU32); + const dummy = device.createBuffer({ size: 16, usage: GPUBufferUsage.STORAGE }); + const outBytes = 2 * PG * PAIRS * 4 * 4; + const outBuf = device.createBuffer({ + size: outBytes, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC, + }); + const paramsBuf = device.createBuffer({ + size: 16, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + device.queue.writeBuffer(paramsBuf, 0, new Uint32Array([PAIRS, T, 0, 0])); + + const bind = device.createBindGroup({ + layout, + entries: [ + { binding: 0, resource: { buffer: inBuf } }, + { binding: 1, resource: { buffer: dummy } }, + { binding: 2, resource: { buffer: outBuf } }, + { binding: 3, resource: { buffer: paramsBuf } }, + ], + }); + + const samples = await timeDispatch(device, pipeline, bind, numWgs, reps, DISP); + const sanityOk = await readNonZero(device, outBuf, 8); + const med = median(samples); + const totalOps = PAIRS * DISP; + const nsPerOp = (med * 1e6) / totalOps; + + log( + sanityOk ? 'ok' : 'err', + `S=${s}: median=${med.toFixed(3)}ms min=${Math.min(...samples).toFixed(3)}ms max=${Math.max(...samples).toFixed(3)}ms ns/op=${nsPerOp.toFixed(2)} sanity=${sanityOk ? 
'OK' : 'FAIL'}`, + ); + + inBuf.destroy(); + dummy.destroy(); + outBuf.destroy(); + paramsBuf.destroy(); + + return { + s, + wgi: WGI, + T, + num_wgs: numWgs, + pairs: PAIRS, + disp: DISP, + total_ops: totalOps, + median_ms: med, + min_ms: Math.min(...samples), + max_ms: Math.max(...samples), + ns_per_op: nsPerOp, + samples_ms: samples, + sanity_ok: sanityOk, + }; +} + +function parseParams() { + const qp = new URLSearchParams(window.location.search); + const reps = parseInt(qp.get('reps') ?? '5', 10); + if (!Number.isFinite(reps) || reps <= 0 || reps > 50) { + throw new Error(`?reps must be in (0, 50], got ${qp.get('reps')}`); + } + const pairsStr = qp.get('pairs'); + if (pairsStr !== null) { + const v = parseInt(pairsStr, 10); + if (!Number.isFinite(v) || v <= 0 || v > (1 << 20)) { + throw new Error(`?pairs must be in (0, 2^20], got ${pairsStr}`); + } + PAIRS = v; + } + const wgiStr = qp.get('wgi'); + if (wgiStr !== null) { + const v = parseInt(wgiStr, 10); + if (!Number.isFinite(v) || v <= 0 || v > 1024) { + throw new Error(`?wgi must be in (0, 1024], got ${wgiStr}`); + } + WGI = v; + } + const dispStr = qp.get('disp'); + if (dispStr !== null) { + const v = parseInt(dispStr, 10); + if (!Number.isFinite(v) || v <= 0 || v > 64) { + throw new Error(`?disp must be in (0, 64], got ${dispStr}`); + } + DISP = v; + } + const sStr = qp.get('s'); + if (sStr !== null) { + const list = sStr.split(',').map(v => parseInt(v, 10)); + for (const v of list) { + if (!Number.isFinite(v) || v <= 0 || v > 256) { + throw new Error(`?s entries must be in (0, 256], got ${v}`); + } + } + S_SWEEP = list; + } + for (const v of S_SWEEP) { + if (PAIRS % v !== 0) { + throw new Error(`S=${v} does not divide PAIRS=${PAIRS}`); + } + } + return { reps, pairs: PAIRS, wgi: WGI, disp: DISP, s_sweep: S_SWEEP }; +} + +async function main() { + try { + if (!('gpu' in navigator)) { + throw new Error('navigator.gpu missing — WebGPU not available'); + } + const params = parseParams(); + benchState.params 
= params; + log( + 'info', + `params: reps=${params.reps} pairs=${params.pairs} wgi=${params.wgi} disp=${params.disp} s=[${params.s_sweep.join(',')}]`, + ); + + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + + const p = BN254_BASE_FIELD; + const miscParams = compute_misc_params(p, 13); + const R = miscParams.r; + + const sm = new ShaderManager(4, PAIRS, BN254_CURVE_CONFIG, false); + + let seed = 0x7b10; + for (const s of S_SWEEP) { + try { + const r = await runOne(device, sm, s, params.reps, R, p, seed); + benchState.results.push(r); + resultsClient.postProgress({ kind: 'batch_done', s, median_ms: r.median_ms, ns_per_op: r.ns_per_op, sanity_ok: r.sanity_ok }); + seed += 0x10; + } catch (e) { + const msg = e instanceof Error ? e.message : String(e); + log('err', `S=${s} failed: ${msg} — STOPPING sweep at first failure`); + benchState.state = 'error'; + benchState.error = msg; + return; + } + } + + benchState.state = 'done'; + log('ok', 'all sizes done'); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main() + .catch(e => { + const msg = e instanceof Error ? e.message : String(e); + log('err', `unhandled: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + }) + .finally(() => { + postFinal().catch(() => {}); + }); diff --git a/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.html b/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.html new file mode 100644 index 000000000000..31b300972d08 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.html @@ -0,0 +1,22 @@ + + + + + cuZK CSR -> v2 active_sums layout converter bench (WebGPU) + + + +

cuZK CSR -> v2 active_sums layout converter

+

Query params: ?subtasks=T&columns=C&input=N&wg=W&disp=D&reps=R&validate=1

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.ts b/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.ts new file mode 100644 index 000000000000..44a0458b96d1 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-csr-to-v2.ts @@ -0,0 +1,517 @@ +/// +// Standalone microbench + validator for the cuZK CSR -> v2 active_sums +// layout converter. The converter consumes the cuZK transpose output +// (val_idx + row_ptr) and the packed 8xu32 cached bases, and emits the +// bucket-major active_sums + per-bucket counts/offsets that the v2 +// pair-tree expects. +// +// Inputs are synthetic — random bucket assignments per (subtask, scalar) +// — so the harness validates the converter in isolation without needing +// the full cuZK convert + decompose + transpose pipeline. +// +// Validation: byte-equivalent compare of active_sums_x/y, active_counts, +// and active_offsets against a host-side reference. The reference is +// trivial: active_sums[k] = new_point[val_idx[k]], counts[b] = +// row_ptr[b+1]-row_ptr[b], offsets[b] = row_ptr[b]. +// +// Timing methodology mirrors bench-planner: warmup, then DISP back-to- +// back dispatches in one encoder, divided by DISP for per-call time. 
import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
import { makeResultsClient } from './results_post.js';

// Bench parameters. Each one can be overridden from the query string; see
// parseParams() for the accepted keys.
let NUM_SUBTASKS = 16;
let NUM_COLUMNS = 4096;
let INPUT_SIZE = 1 << 14;
let WG = 64;
let DISP = 32;
let REPS = 5;
let VALIDATE = false;
let SEED = 0xc5af;

/**
 * Builds a deterministic 32-bit linear congruential generator.
 *
 * @param seed Any number; it is coerced to u32 and a zero seed is replaced by 1.
 * @returns A function yielding the next u32 state on every call.
 */
function makeRng(seed: number): () => number {
  let state = (seed >>> 0) || 1;
  return () => {
    state = (Math.imul(state, 1664525) + 1013904223) >>> 0;
    return state;
  };
}

/**
 * Builds a synthetic CSR description: for every subtask, each of the
 * `inputSize` point indices is assigned to a pseudo-random bucket.
 *
 * @returns `rowPtr` holding `numColumns + 1` prefix sums per subtask and
 *          `valIdx` holding `inputSize` point indices per subtask, grouped by
 *          bucket in ascending point-index order within a bucket.
 */
function buildSyntheticCsr(
  numSubtasks: number,
  numColumns: number,
  inputSize: number,
  seed: number,
): { rowPtr: Uint32Array; valIdx: Uint32Array } {
  const rng = makeRng(seed);
  const rowPtr = new Uint32Array(numSubtasks * (numColumns + 1));
  const valIdx = new Uint32Array(numSubtasks * inputSize);

  for (let s = 0; s < numSubtasks; s++) {
    const bucketOf = new Uint32Array(inputSize);
    const perBucketCount = new Uint32Array(numColumns);
    for (let i = 0; i < inputSize; i++) {
      // Only the upper 16 bits of each draw are used; two draws are glued
      // together into one 32-bit value before reducing modulo numColumns.
      const hi = (rng() >>> 16) & 0xffff;
      const lo = (rng() >>> 16) & 0xffff;
      const r = hi * 0x10000 + lo;
      const b = r % numColumns;
      bucketOf[i] = b;
      perBucketCount[b]++;
    }
    // Exclusive prefix sum of the bucket sizes.
    const rpBase = s * (numColumns + 1);
    rowPtr[rpBase] = 0;
    for (let b = 0; b < numColumns; b++) {
      rowPtr[rpBase + b + 1] = rowPtr[rpBase + b] + perBucketCount[b];
    }
    // Scatter the point indices into their bucket's slot range.
    const cursor = new Uint32Array(numColumns);
    const viBase = s * inputSize;
    for (let i = 0; i < inputSize; i++) {
      const b = bucketOf[i];
      const slot = rowPtr[rpBase + b] + cursor[b];
      valIdx[viBase + slot] = i;
      cursor[b]++;
    }
  }
  return { rowPtr, valIdx };
}

/**
 * Returns the middle element of the sorted samples (the upper of the two
 * middle elements for an even count), or NaN for an empty list.
 */
function median(xs: number[]): number {
  if (xs.length === 0) return NaN;
  const s = xs.slice().sort((a, b) => a - b);
  return s[Math.floor(s.length / 2)];
}

interface BenchResult {
  num_subtasks: number;
  num_columns: number;
  input_size: number;
  total_slots: number;
  total_buckets: number;
  wg: number;
  disp: number;
  reps: number;
  active_sums_us: { min: number; median: number; max: number };
  meta_us: { min: number; median: number; max: number };
  combined_us: { min: number; median: number; max: number };
  ns_per_slot_combined: number;
  validated: boolean;
}

interface BenchState {
  state: 'boot' | 'running' | 'done' | 'error';
  params: Record<string, unknown> | null;
  results: BenchResult[];
  error: string | null;
  log: string[];
}

// The state object and the run id are attached to `window`, alongside the
// results client that posts them.
const benchState: BenchState = { state: 'boot', params: null, results: [], error: null, log: [] };
(window as unknown as { __bench: BenchState }).__bench = benchState;
const resultsClient = makeResultsClient({ page: 'bench-csr-to-v2' });
(window as unknown as { __runId: string }).__runId = resultsClient.runId;

/** Posts the final bench state (results, log, environment info) to the results client. */
async function postFinal(): Promise<void> {
  await resultsClient.postResults({
    state: benchState.state,
    params: benchState.params,
    results: benchState.results,
    error: benchState.error,
    log: benchState.log,
    userAgent: navigator.userAgent,
    hardwareConcurrency: navigator.hardwareConcurrency,
  });
}

const $log = document.getElementById('log') as HTMLDivElement;

/** Appends a message to the on-page log, the recorded bench log and the console. */
function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) {
  const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : '';
  const row = document.createElement('div');
  row.className = cls;
  row.textContent = msg;
  $log.appendChild(row);
  benchState.log.push(`[${level}] ${msg}`);
  console.log(`[bench-csr-to-v2] ${msg}`);
}

/**
 * Compiles one WGSL module and builds a compute pipeline with entry point `main`.
 *
 * @throws Error when the compilation info contains at least one error message
 *         (the first four are included in the exception text).
 */
async function compileOne(
  device: GPUDevice,
  code: string,
  key: string,
  layout: GPUBindGroupLayout,
): Promise<GPUComputePipeline> {
  const module = device.createShaderModule({ code });
  const info = await module.getCompilationInfo();
  const errLines: string[] = [];
  for (const m of info.messages) {
    const line = `[shader ${key}] ${m.type}: ${m.message} (line ${m.lineNum}, col ${m.linePos})`;
    if (m.type === 'error') {
      console.error(line);
      log('err', line);
      errLines.push(line);
    } else {
      console.warn(line);
    }
  }
  if (errLines.length > 0) {
    throw new Error(`WGSL compile failed for ${key}: ${errLines.slice(0, 4).join(' | ')}`);
  }
  return device.createComputePipelineAsync({
    layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }),
    compute: { module, entryPoint: 'main' },
  });
}

/**
 * Copies the first `byteLength` bytes of `buf` to the host and returns them as u32 words.
 * The staging buffer is destroyed even when the copy or the map fails.
 */
async function readbackU32(device: GPUDevice, buf: GPUBuffer, byteLength: number): Promise<Uint32Array> {
  const staging = device.createBuffer({ size: byteLength, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST });
  try {
    const enc = device.createCommandEncoder();
    enc.copyBufferToBuffer(buf, 0, staging, 0, byteLength);
    device.queue.submit([enc.finish()]);
    await device.queue.onSubmittedWorkDone();
    await staging.mapAsync(GPUMapMode.READ);
    // slice(0) detaches the data from the mapping so it survives unmap().
    const out = new Uint32Array(staging.getMappedRange().slice(0));
    staging.unmap();
    return out;
  } finally {
    staging.destroy();
  }
}

/**
 * Runs the converter bench once with the current module-level parameters:
 * uploads synthetic inputs, warms up, optionally validates the GPU output
 * against a host reference, then times the active-sums pass, the meta pass and
 * both together (DISP dispatches per submit, REPS submits each).
 *
 * All GPU buffers created here are destroyed on exit, including on failure.
 *
 * @throws Error when a dispatch would exceed the device's per-dimension
 *         workgroup limit or when a shader fails to compile.
 */
async function runOne(device: GPUDevice, sm: ShaderManager): Promise<BenchResult> {
  const totalSlots = NUM_SUBTASKS * INPUT_SIZE;
  const totalBuckets = NUM_SUBTASKS * NUM_COLUMNS;
  log(
    'info',
    `=== T=${NUM_SUBTASKS} columns=${NUM_COLUMNS} input_size=${INPUT_SIZE} ` +
      `total_slots=${totalSlots} total_buckets=${totalBuckets} WG=${WG} DISP=${DISP} REPS=${REPS}`,
  );

  // Both kernels are dispatched one-dimensionally, so fail early (before any
  // allocation) if the requested shape cannot be dispatched on this device.
  const groupsActive = Math.ceil(totalSlots / WG);
  const groupsMeta = Math.ceil(totalBuckets / WG);
  const maxGroups = device.limits.maxComputeWorkgroupsPerDimension;
  if (groupsActive > maxGroups || groupsMeta > maxGroups) {
    throw new Error(
      `dispatch too large: active=${groupsActive} meta=${groupsMeta} workgroups, device limit is ${maxGroups}`,
    );
  }

  // Every buffer created below is tracked so the finally block can release it.
  const owned: GPUBuffer[] = [];
  const track = (b: GPUBuffer): GPUBuffer => {
    owned.push(b);
    return b;
  };
  const mk = (bytes: number, copySrc = false, copyDst = false): GPUBuffer => {
    let usage = GPUBufferUsage.STORAGE;
    if (copySrc) usage |= GPUBufferUsage.COPY_SRC;
    if (copyDst) usage |= GPUBufferUsage.COPY_DST;
    return track(device.createBuffer({ size: bytes, usage }));
  };
  const mkUniform = (words: number[]): GPUBuffer => {
    const b = track(device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }));
    device.queue.writeBuffer(b, 0, new Uint32Array(words));
    return b;
  };

  try {
    // Synthetic packed bases: 32 bytes per element. Random bytes are fine
    // — the converter is a layout-equivalent byte copy.
    const rng = makeRng(SEED ^ 0xfeed);
    const basesBytes = INPUT_SIZE * 32;
    const basesX = new Uint8Array(basesBytes);
    const basesY = new Uint8Array(basesBytes);
    for (let i = 0; i < basesBytes; i += 4) {
      const v = rng();
      basesX[i] = v & 0xff;
      basesX[i + 1] = (v >>> 8) & 0xff;
      basesX[i + 2] = (v >>> 16) & 0xff;
      basesX[i + 3] = (v >>> 24) & 0xff;
      const w = rng();
      basesY[i] = w & 0xff;
      basesY[i + 1] = (w >>> 8) & 0xff;
      basesY[i + 2] = (w >>> 16) & 0xff;
      basesY[i + 3] = (w >>> 24) & 0xff;
    }
    log('info', `synthetic packed bases: ${(basesBytes * 2 / 1024).toFixed(1)} KiB total`);

    const { rowPtr, valIdx } = buildSyntheticCsr(NUM_SUBTASKS, NUM_COLUMNS, INPUT_SIZE, SEED);
    log('info', `synthetic CSR: rowPtr=${(rowPtr.byteLength / 1024).toFixed(1)} KiB valIdx=${(valIdx.byteLength / 1024).toFixed(1)} KiB`);

    const M = totalSlots;
    const activeBytes = 2 * 2 * M * 16;
    const basesXBuf = mk(basesBytes, false, true);
    const basesYBuf = mk(basesBytes, false, true);
    const valIdxBuf = mk(valIdx.byteLength, false, true);
    const rowPtrBuf = mk(rowPtr.byteLength, false, true);
    const activeBuf = mk(activeBytes, true);
    const countsBuf = mk(totalBuckets * 4, true);
    const offsetsBuf = mk(totalBuckets * 4, true);
    device.queue.writeBuffer(basesXBuf, 0, basesX);
    device.queue.writeBuffer(basesYBuf, 0, basesY);
    device.queue.writeBuffer(valIdxBuf, 0, valIdx);
    device.queue.writeBuffer(rowPtrBuf, 0, rowPtr);

    const paramsActiveBuf = mkUniform([totalSlots, M, 0, 0]);
    const paramsMetaBuf = mkUniform([NUM_COLUMNS, totalBuckets, 0, 0]);

    const activeLayout = device.createBindGroupLayout({
      entries: [
        { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
        { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
        { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
        { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
        { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } },
      ],
    });
    const metaLayout = device.createBindGroupLayout({
      entries: [
        { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
        { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
        { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
        { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } },
      ],
    });
    const activePipeline = await compileOne(device, sm.gen_csr_to_v2_active_sums_shader(WG), `csr-to-v2-active_sums-wg${WG}`, activeLayout);
    const metaPipeline = await compileOne(device, sm.gen_csr_to_v2_meta_shader(WG), `csr-to-v2-meta-wg${WG}`, metaLayout);

    const activeBind = device.createBindGroup({
      layout: activeLayout,
      entries: [
        { binding: 0, resource: { buffer: valIdxBuf } },
        { binding: 1, resource: { buffer: basesXBuf } },
        { binding: 2, resource: { buffer: basesYBuf } },
        { binding: 3, resource: { buffer: activeBuf } },
        { binding: 4, resource: { buffer: paramsActiveBuf } },
      ],
    });
    const metaBind = device.createBindGroup({
      layout: metaLayout,
      entries: [
        { binding: 0, resource: { buffer: rowPtrBuf } },
        { binding: 1, resource: { buffer: countsBuf } },
        { binding: 2, resource: { buffer: offsetsBuf } },
        { binding: 3, resource: { buffer: paramsMetaBuf } },
      ],
    });

    log('info', `dispatch shape: active=${groupsActive} wgs, meta=${groupsMeta} wgs`);

    // Each helper records exactly one compute pass into the given encoder.
    const encodeActive = (enc: GPUCommandEncoder): void => {
      const pass = enc.beginComputePass();
      pass.setPipeline(activePipeline);
      pass.setBindGroup(0, activeBind);
      pass.dispatchWorkgroups(groupsActive, 1, 1);
      pass.end();
    };
    const encodeMeta = (enc: GPUCommandEncoder): void => {
      const pass = enc.beginComputePass();
      pass.setPipeline(metaPipeline);
      pass.setBindGroup(0, metaBind);
      pass.dispatchWorkgroups(groupsMeta, 1, 1);
      pass.end();
    };
    // Records `encodeOnce` DISP times into one encoder and returns the wall
    // time (ms) from submit until the queue reports the work as done. Encoding
    // happens before the clock starts.
    const timeSubmit = async (encodeOnce: (enc: GPUCommandEncoder) => void): Promise<number> => {
      const enc = device.createCommandEncoder();
      for (let d = 0; d < DISP; d++) encodeOnce(enc);
      const t0 = performance.now();
      device.queue.submit([enc.finish()]);
      await device.queue.onSubmittedWorkDone();
      return performance.now() - t0;
    };

    // Warmup: one active pass followed by one meta pass. This is also the run
    // whose output the optional validation below inspects.
    {
      const enc = device.createCommandEncoder();
      encodeActive(enc);
      encodeMeta(enc);
      device.queue.submit([enc.finish()]);
      await device.queue.onSubmittedWorkDone();
    }
    log('info', 'warmup done');

    let validated = false;
    if (VALIDATE) {
      const gpuActive = await readbackU32(device, activeBuf, activeBytes);
      const gpuCounts = await readbackU32(device, countsBuf, totalBuckets * 4);
      const gpuOffsets = await readbackU32(device, offsetsBuf, totalBuckets * 4);
      const refBasesX = new Uint32Array(basesX.buffer, basesX.byteOffset, INPUT_SIZE * 8);
      const refBasesY = new Uint32Array(basesY.buffer, basesY.byteOffset, INPUT_SIZE * 8);

      const mismatches: string[] = [];
      let xFails = 0;
      let yFails = 0;
      // NOTE(review): X is read with a stride of 8 u32 per slot, which spans
      // 8*M u32 for the X plane, yet the Y plane is assumed to start at 4*M
      // u32. Confirm against the layout the active_sums shader writes.
      const planeYU32Off = 2 * 2 * M;
      for (let slot = 0; slot < totalSlots; slot++) {
        const ptIdx = valIdx[slot];
        for (let w = 0; w < 8; w++) {
          const got = gpuActive[slot * 8 + w];
          const want = refBasesX[ptIdx * 8 + w];
          if (got !== want) {
            xFails++;
            if (mismatches.length < 8) mismatches.push(`activeX[slot=${slot} w=${w}]: gpu=${got} ref=${want}`);
          }
          const gotY = gpuActive[planeYU32Off + slot * 8 + w];
          const wantY = refBasesY[ptIdx * 8 + w];
          if (gotY !== wantY) {
            yFails++;
            if (mismatches.length < 8) mismatches.push(`activeY[slot=${slot} w=${w}]: gpu=${gotY} ref=${wantY}`);
          }
        }
      }
      let mFails = 0;
      for (let s = 0; s < NUM_SUBTASKS; s++) {
        const rpBase = s * (NUM_COLUMNS + 1);
        const ccBase = s * NUM_COLUMNS;
        for (let b = 0; b < NUM_COLUMNS; b++) {
          const wantCount = rowPtr[rpBase + b + 1] - rowPtr[rpBase + b];
          const wantOffset = rowPtr[rpBase + b];
          if (gpuCounts[ccBase + b] !== wantCount) {
            mFails++;
            if (mismatches.length < 12) mismatches.push(`counts[s=${s} b=${b}]: gpu=${gpuCounts[ccBase + b]} ref=${wantCount}`);
          }
          if (gpuOffsets[ccBase + b] !== wantOffset) {
            mFails++;
            if (mismatches.length < 12) mismatches.push(`offsets[s=${s} b=${b}]: gpu=${gpuOffsets[ccBase + b]} ref=${wantOffset}`);
          }
        }
      }
      if (xFails === 0 && yFails === 0 && mFails === 0) {
        validated = true;
        log('ok', `validation: PASS — converter output byte-equivalent to host reference (${totalSlots} slots, ${totalBuckets} buckets)`);
      } else {
        log('err', `validation: FAIL — activeX mismatches=${xFails}, activeY=${yFails}, meta=${mFails}`);
        for (const m of mismatches.slice(0, 12)) log('err', `  ${m}`);
      }
    }

    // Timed: DISP back-to-back encoded as one submit per rep, split active vs meta.
    const activeSamples: number[] = [];
    const metaSamples: number[] = [];
    const combinedSamples: number[] = [];
    for (let r = 0; r < REPS; r++) {
      activeSamples.push(await timeSubmit(encodeActive));
      metaSamples.push(await timeSubmit(encodeMeta));
      combinedSamples.push(
        await timeSubmit(enc => {
          encodeActive(enc);
          encodeMeta(enc);
        }),
      );
    }

    // Samples are milliseconds per submit; convert to microseconds per dispatch.
    const stat = (xs: number[]): { min: number; median: number; max: number } => {
      const med = median(xs);
      return { min: (Math.min(...xs) / DISP) * 1000, median: (med / DISP) * 1000, max: (Math.max(...xs) / DISP) * 1000 };
    };
    const activeStats = stat(activeSamples);
    const metaStats = stat(metaSamples);
    const combinedStats = stat(combinedSamples);
    const nsPerSlotCombined = (combinedStats.median * 1000) / totalSlots;
    log(
      'ok',
      `active_sums per-dispatch: median=${activeStats.median.toFixed(2)}μs min=${activeStats.min.toFixed(2)}μs max=${activeStats.max.toFixed(2)}μs`,
    );
    log(
      'ok',
      `meta per-dispatch: median=${metaStats.median.toFixed(2)}μs min=${metaStats.min.toFixed(2)}μs max=${metaStats.max.toFixed(2)}μs`,
    );
    log(
      'ok',
      `combined per-dispatch: median=${combinedStats.median.toFixed(2)}μs min=${combinedStats.min.toFixed(2)}μs max=${combinedStats.max.toFixed(2)}μs ` +
        `→ ${nsPerSlotCombined.toFixed(3)} ns/slot (${totalSlots} slots)`,
    );

    return {
      num_subtasks: NUM_SUBTASKS,
      num_columns: NUM_COLUMNS,
      input_size: INPUT_SIZE,
      total_slots: totalSlots,
      total_buckets: totalBuckets,
      wg: WG,
      disp: DISP,
      reps: REPS,
      active_sums_us: activeStats,
      meta_us: metaStats,
      combined_us: combinedStats,
      ns_per_slot_combined: nsPerSlotCombined,
      validated,
    };
  } finally {
    for (const b of owned) b.destroy();
  }
}

/**
 * Reads a positive integer query parameter.
 *
 * @returns `fallback` when the key is absent or empty, otherwise the parsed value.
 * @throws Error when the value does not parse to an integer >= 1.
 */
function readPositiveInt(qp: URLSearchParams, key: string, fallback: number): number {
  const raw = qp.get(key);
  if (!raw) return fallback;
  const v = parseInt(raw, 10);
  if (!Number.isFinite(v) || v < 1) {
    throw new Error(`?${key} must be a positive integer, got ${raw}`);
  }
  return v;
}

/**
 * Applies query-string overrides (`subtasks`, `columns`, `input`, `wg`, `disp`,
 * `reps`, `seed`, `validate=1`) to the module-level bench parameters and
 * returns the effective values.
 *
 * @throws Error when a numeric override is not a valid integer in range.
 */
function parseParams() {
  const qp = new URLSearchParams(window.location.search);
  NUM_SUBTASKS = readPositiveInt(qp, 'subtasks', NUM_SUBTASKS);
  NUM_COLUMNS = readPositiveInt(qp, 'columns', NUM_COLUMNS);
  INPUT_SIZE = readPositiveInt(qp, 'input', INPUT_SIZE);
  WG = readPositiveInt(qp, 'wg', WG);
  DISP = readPositiveInt(qp, 'disp', DISP);
  REPS = readPositiveInt(qp, 'reps', REPS);
  const seedRaw = qp.get('seed');
  if (seedRaw) {
    // Any integer is acceptable as a seed; makeRng coerces it to u32.
    const seed = parseInt(seedRaw, 10);
    if (!Number.isFinite(seed)) throw new Error(`?seed must be an integer, got ${seedRaw}`);
    SEED = seed;
  }
  if (qp.get('validate') === '1') VALIDATE = true;
  return {
    subtasks: NUM_SUBTASKS,
    columns: NUM_COLUMNS,
    input_size: INPUT_SIZE,
    wg: WG,
    disp: DISP,
    reps: REPS,
    seed: SEED,
    validate: VALIDATE,
  };
}

/**
 * Page entry point: parses parameters, acquires the device, runs the bench
 * once and records the outcome in `benchState`. Failures are logged and stored
 * in `benchState.error` instead of being rethrown.
 */
async function main() {
  try {
    if (!('gpu' in navigator)) throw new Error('navigator.gpu missing');
    const params = parseParams();
    benchState.params = params;
    log('info', `params: ${JSON.stringify(params)}`);
    benchState.state = 'running';
    const device = await get_device();
    log('info', 'WebGPU device acquired');
    const sm = new ShaderManager(NUM_SUBTASKS, NUM_COLUMNS, BN254_CURVE_CONFIG, false);
    const r = await runOne(device, sm);
    benchState.results.push(r);
    resultsClient.postProgress({
      kind: 'csr_to_v2_done',
      active_sums_us: r.active_sums_us,
      meta_us: r.meta_us,
      combined_us: r.combined_us,
      ns_per_slot_combined: r.ns_per_slot_combined,
      validated: r.validated,
    });
    benchState.state = 'done';
    log('ok', 'done');
  } catch (e) {
    const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e);
    log('err', `FATAL: ${msg}`);
    benchState.state = 'error';
    benchState.error = msg;
  }
}

main()
  .catch(e => {
    const msg = e instanceof Error ? e.message : String(e);
    log('err', `unhandled: ${msg}`);
    benchState.state = 'error';
    benchState.error = msg;
  })
  .finally(() => {
    // Best-effort upload; a failed post must not mask the bench outcome.
    postFinal().catch(() => {});
  });
// diff --git a/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.html b/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.html new file mode 100644 index 000000000000..9605a79b5605 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.html @@ -0,0 +1,37 @@ + + + + + Workgroup-scan fused batch-affine round bench (WebGPU) + + + +

Workgroup-scan fused batch-affine round bench (WebGPU)

+

Query params: ?reps=R&total=N&sizes=A,B,C&skip_correctness=1

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts b/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts new file mode 100644 index 000000000000..264bbc0a6f77 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-fused-wg-scan.ts @@ -0,0 +1,549 @@ +/// +// Standalone WebGPU bench + correctness oracle for the new +// `batch_affine_fused_wg_scan` kernel (workgroup-scan fused batch-affine +// round, validated to mirror bench_batch_affine's 22 ns/pair design with +// bucket-indirect loads via pair_target_meta). +// +// Inputs: TOTAL_PAIRS on-curve BN254 G1 affine pairs (P_i, Q_i), +// distributed across ONE synthetic subtask. Each pair maps to a unique +// bucket (`pair_target_meta[2i] = i`, `pair_target_meta[2i+1] = i`, +// `val_idx[i] = i`), so the scheduler's "distinct-buckets within a +// subtask round" invariant holds trivially. The kernel writes the +// affine sum `R_i = P_i + Q_i` to running_x[i] / running_y[i]; we +// decode packed Mont form back to canonical and compare to noble's +// reference P.add(Q). +// +// SAFETY: NO MSM pipeline touched. The shader uses only compile-time- +// const loop bounds (BS, TPB, NUM_WORDS). One dispatch per measurement. 
import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js';
import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js';
import { makeResultsClient } from './results_post.js';
import { bn254 } from '@noble/curves/bn254';

const G1 = bn254.G1.ProjectivePoint;
const FR_ORDER = bn254.fields.Fr.ORDER;

const DEFAULT_TOTAL_PAIRS = 1 << 16;
let TOTAL_PAIRS = DEFAULT_TOTAL_PAIRS;

const DEFAULT_BATCH_SIZES = [256, 512, 1024, 2048] as const;
let BATCH_SIZES: readonly number[] = DEFAULT_BATCH_SIZES;

let SKIP_CORRECTNESS = false;

/**
 * Splits a batch size into threads-per-block (`tpb`) and per-thread batch (`bs`).
 * Sizes up to 64 run one pair per thread; larger sizes use 64 threads.
 *
 * @throws Error when a size above 64 is not a multiple of 64, since that would
 *         produce a fractional `bs`.
 */
function tpbBsFor(batchSize: number): { tpb: number; bs: number } {
  if (batchSize <= 64) return { tpb: batchSize, bs: 1 };
  if (batchSize % 64 !== 0) {
    throw new Error(`batch_size=${batchSize} must be a multiple of 64 when larger than 64`);
  }
  return { tpb: 64, bs: batchSize / 64 };
}

/**
 * Builds a deterministic 32-bit linear congruential generator.
 *
 * @param seed Any number; it is coerced to u32 and a zero seed is replaced by 1.
 * @returns A function yielding the next u32 state on every call.
 */
function makeRng(seed: number): () => number {
  let state = (seed >>> 0) || 1;
  return () => {
    state = (Math.imul(state, 1664525) + 1013904223) >>> 0;
    return state;
  };
}

/**
 * Draws a value in the open interval (0, p) by rejection sampling.
 *
 * Bytes are taken from the TOP byte of each generator output. The low byte of
 * a mod-2^32 LCG repeats with period 256, so using it would make this function
 * cycle through a handful of values.
 */
function randomBelow(p: bigint, rng: () => number): bigint {
  const bitlen = p.toString(2).length;
  const byteLen = Math.ceil(bitlen / 8);
  const mask = (1n << BigInt(bitlen)) - 1n;
  for (;;) {
    let v = 0n;
    for (let i = 0; i < byteLen; i++) v = (v << 8n) | BigInt((rng() >>> 24) & 0xff);
    v &= mask;
    if (v > 0n && v < p) return v;
  }
}

/**
 * Returns the middle element of the sorted samples (the upper of the two
 * middle elements for an even count), or NaN for an empty list.
 */
function median(xs: number[]): number {
  if (xs.length === 0) return NaN;
  const s = xs.slice().sort((a, b) => a - b);
  return s[Math.floor(s.length / 2)];
}

/** Encodes the low 256 bits of `v` as eight little-endian u32 limbs. */
function biToLe32u32(v: bigint): Uint32Array {
  const out = new Uint32Array(8);
  let x = v;
  for (let i = 0; i < 8; i++) {
    out[i] = Number(x & 0xffffffffn);
    x >>= 32n;
  }
  return out;
}

/** Decodes eight little-endian u32 limbs starting at `off` into a bigint. */
function le32u32ToBi(u32: Uint32Array, off: number): bigint {
  let v = 0n;
  for (let i = 7; i >= 0; i--) v = (v << 32n) | BigInt(u32[off + i] >>> 0);
  return v;
}

/**
 * Modular inverse of `a` modulo `m` via the extended Euclidean algorithm.
 *
 * @throws Error when `a` and `m` are not coprime.
 */
function modInverse(a: bigint, m: bigint): bigint {
  let g = a % m;
  let r = m;
  let x = 0n;
  let y = 1n;
  while (g !== 0n) {
    const q = r / g;
    [r, g] = [g, r - q * g];
    [x, y] = [y, x - q * y];
  }
  if (r !== 1n) throw new Error('modInverse: value is not invertible');
  return ((x % m) + m) % m;
}

interface PerSizeResult {
  batch_size: number;
  tpb: number;
  bs: number;
  num_wgs: number;
  total_pairs: number;
  median_ms: number;
  min_ms: number;
  max_ms: number;
  ns_per_pair: number;
  samples_ms: number[];
  correctness: 'pass' | 'fail' | 'skipped';
  correctness_first_fail?: { i: number; expected_x: string; got_x: string; expected_y: string; got_y: string };
}

interface BenchState {
  state: 'boot' | 'running' | 'done' | 'error';
  params: { reps: number; total: number; sizes: readonly number[]; skip_correctness: boolean } | null;
  results: PerSizeResult[];
  error: string | null;
  log: string[];
}

const benchState: BenchState = {
  state: 'boot',
  params: null,
  results: [],
  error: null,
  log: [],
};
// The state object and the run id are attached to `window`.
(window as unknown as { __bench: BenchState }).__bench = benchState;

const resultsClient = makeResultsClient({ page: 'bench-fused-wg-scan' });
(window as unknown as { __runId: string }).__runId = resultsClient.runId;

/** Posts the final bench state (results, log, environment info) to the results client. */
async function postFinal(): Promise<void> {
  await resultsClient.postResults({
    state: benchState.state,
    params: benchState.params,
    results: benchState.results,
    error: benchState.error,
    log: benchState.log,
    userAgent: navigator.userAgent,
    hardwareConcurrency: navigator.hardwareConcurrency,
  });
}

const $log = document.getElementById('log') as HTMLDivElement;

/** Appends a message to the on-page log, the recorded bench log and the console. */
function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) {
  const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : '';
  const row = document.createElement('div');
  row.className = cls;
  row.textContent = msg;
  $log.appendChild(row);
  benchState.log.push(`[${level}] ${msg}`);
  console.log(`[bench-fused-wg-scan] ${msg}`);
}

/**
 * Compiles the WGSL module and builds the compute pipeline together with its
 * nine-entry bind group layout (bindings 0..8).
 *
 * @throws Error when the compilation info contains at least one error message.
 */
async function createPipeline(
  device: GPUDevice,
  code: string,
  cacheKey: string,
): Promise<{ pipeline: GPUComputePipeline; layout: GPUBindGroupLayout }> {
  const module = device.createShaderModule({ code });
  const info = await module.getCompilationInfo();
  const errLines: string[] = [];
  for (const msg of info.messages) {
    const line = `[shader ${cacheKey}] ${msg.type}: ${msg.message} (line ${msg.lineNum}, col ${msg.linePos})`;
    if (msg.type === 'error') {
      console.error(line);
      log('err', line);
      errLines.push(line);
    } else {
      console.warn(line);
      log('warn', line);
    }
  }
  if (errLines.length > 0) {
    throw new Error(`WGSL compile failed for ${cacheKey}: ${errLines.join(' | ')}`);
  }
  const layout = device.createBindGroupLayout({
    entries: [
      { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
      { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
      { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
      { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
      { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
      { binding: 5, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
      { binding: 6, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
      { binding: 7, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
      { binding: 8, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } },
    ],
  });
  const pipeline = await device.createComputePipelineAsync({
    layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }),
    compute: { module, entryPoint: 'main' },
  });
  return { pipeline, layout };
}

interface PointPair {
  p: { x: bigint; y: bigint };
  q: { x: bigint; y: bigint };
  r: { x: bigint; y: bigint };
}

/**
 * Generates `n` affine pairs (P, Q) with their reference sum R = P + Q.
 * Candidates are redrawn when either point or the sum is the identity, or when
 * P and Q share an x coordinate.
 */
function buildPairs(n: number, seed: number): PointPair[] {
  const rng = makeRng(seed);
  const out: PointPair[] = [];
  while (out.length < n) {
    const pp = G1.BASE.multiply(randomBelow(FR_ORDER, rng));
    const qp = G1.BASE.multiply(randomBelow(FR_ORDER, rng));
    if (pp.is0() || qp.is0()) continue;
    const pa = pp.toAffine();
    const qa = qp.toAffine();
    if (pa.x === qa.x) continue;
    const rp = pp.add(qp);
    if (rp.is0()) continue;
    out.push({ p: pa, q: qa, r: rp.toAffine() });
  }
  return out;
}

/**
 * Benchmarks one batch size: uploads the pairs in Montgomery form, runs a
 * warmup dispatch, optionally checks the GPU sums against the reference, then
 * times `reps` dispatches.
 *
 * All GPU buffers created here are destroyed on exit, including on failure.
 *
 * @param R Montgomery radix used to encode the coordinates.
 * @param p Base-field modulus.
 * @throws Error when TOTAL_PAIRS is not a multiple of `batchSize`, when the
 *         dispatch exceeds the device workgroup limit, or when the shader
 *         fails to compile.
 */
async function runOne(
  device: GPUDevice,
  sm: ShaderManager,
  batchSize: number,
  reps: number,
  R: bigint,
  p: bigint,
  pairs: PointPair[],
): Promise<PerSizeResult> {
  const { tpb, bs } = tpbBsFor(batchSize);
  if (TOTAL_PAIRS % batchSize !== 0) {
    throw new Error(`TOTAL_PAIRS=${TOTAL_PAIRS} must be a multiple of batch_size=${batchSize}`);
  }
  const numWgs = TOTAL_PAIRS / batchSize;
  const maxGroups = device.limits.maxComputeWorkgroupsPerDimension;
  if (numWgs > maxGroups) {
    throw new Error(`num_WGs=${numWgs} exceeds device limit ${maxGroups} for batch_size=${batchSize}`);
  }
  log('info', `=== batch_size=${batchSize}: TPB=${tpb} BS=${bs} num_WGs=${numWgs}`);

  const code = sm.gen_batch_affine_fused_wg_scan_shader(tpb, bs);
  const cacheKey = `bench-fused-wg-scan-T${tpb}-S${bs}`;
  log('info', `compiling shader (${code.length} chars)`);
  // Keep the generated WGSL reachable on `window` under a per-size key.
  (window as unknown as Record<string, unknown>)[`__shader_${batchSize}`] = code;
  const { pipeline, layout } = await createPipeline(device, code, cacheKey);

  const fieldBytes = 32;

  const newPointXAB = new ArrayBuffer(TOTAL_PAIRS * fieldBytes);
  const newPointYAB = new ArrayBuffer(TOTAL_PAIRS * fieldBytes);
  const runningXAB = new ArrayBuffer(TOTAL_PAIRS * fieldBytes);
  const runningYAB = new ArrayBuffer(TOTAL_PAIRS * fieldBytes);

  const nx32 = new Uint32Array(newPointXAB);
  const ny32 = new Uint32Array(newPointYAB);
  const rx32 = new Uint32Array(runningXAB);
  const ry32 = new Uint32Array(runningYAB);

  // P goes into the running buffers, Q into the new-point buffers, both in
  // Montgomery form (coordinate * R mod p).
  for (let i = 0; i < TOTAL_PAIRS; i++) {
    const { p: pp, q: qq } = pairs[i];
    rx32.set(biToLe32u32((pp.x * R) % p), i * 8);
    ry32.set(biToLe32u32((pp.y * R) % p), i * 8);
    nx32.set(biToLe32u32((qq.x * R) % p), i * 8);
    ny32.set(biToLe32u32((qq.y * R) % p), i * 8);
  }

  // Identity mapping: pair i reads point i and targets bucket i.
  const valIdxAB = new Uint32Array(TOTAL_PAIRS);
  const ptmAB = new Uint32Array(TOTAL_PAIRS * 2);
  for (let i = 0; i < TOTAL_PAIRS; i++) {
    valIdxAB[i] = i;
    ptmAB[2 * i] = i;
    ptmAB[2 * i + 1] = i;
  }
  const countAB = new Uint32Array([TOTAL_PAIRS]);
  const paramsAB = new Uint32Array([TOTAL_PAIRS, TOTAL_PAIRS, 0, 0]);

  // Every buffer created below is tracked so the finally block can release it.
  const owned: GPUBuffer[] = [];
  const track = (b: GPUBuffer): GPUBuffer => {
    owned.push(b);
    return b;
  };
  const mkSb = (size: number, copyDst = true, copySrc = false): GPUBuffer => {
    let usage = GPUBufferUsage.STORAGE;
    if (copyDst) usage |= GPUBufferUsage.COPY_DST;
    if (copySrc) usage |= GPUBufferUsage.COPY_SRC;
    return track(device.createBuffer({ size, usage }));
  };

  try {
    const valIdxBuf = mkSb(valIdxAB.byteLength);
    const newPointXBuf = mkSb(newPointXAB.byteLength);
    const newPointYBuf = mkSb(newPointYAB.byteLength);
    const runningXBuf = mkSb(runningXAB.byteLength, true, true);
    const runningYBuf = mkSb(runningYAB.byteLength, true, true);
    const ptmBuf = mkSb(ptmAB.byteLength);
    const prefixBuf = mkSb(TOTAL_PAIRS * 20 * 4, false);
    const countBuf = mkSb(countAB.byteLength);
    const paramsBuf = track(
      device.createBuffer({
        size: 16,
        usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
      }),
    );

    device.queue.writeBuffer(valIdxBuf, 0, valIdxAB);
    device.queue.writeBuffer(newPointXBuf, 0, newPointXAB);
    device.queue.writeBuffer(newPointYBuf, 0, newPointYAB);
    device.queue.writeBuffer(ptmBuf, 0, ptmAB);
    device.queue.writeBuffer(countBuf, 0, countAB);
    device.queue.writeBuffer(paramsBuf, 0, paramsAB);

    const bindGroup = device.createBindGroup({
      layout,
      entries: [
        { binding: 0, resource: { buffer: valIdxBuf } },
        { binding: 1, resource: { buffer: newPointXBuf } },
        { binding: 2, resource: { buffer: newPointYBuf } },
        { binding: 3, resource: { buffer: runningXBuf } },
        { binding: 4, resource: { buffer: runningYBuf } },
        { binding: 5, resource: { buffer: ptmBuf } },
        { binding: 6, resource: { buffer: prefixBuf } },
        { binding: 7, resource: { buffer: countBuf } },
        { binding: 8, resource: { buffer: paramsBuf } },
      ],
    });

    // Runs one dispatch and returns the submit-to-done wall time in ms. With
    // `resetState` the running buffers are first reloaded with P so the kernel
    // computes exactly P + Q.
    const dispatch = async (resetState: boolean): Promise<number> => {
      if (resetState) {
        device.queue.writeBuffer(runningXBuf, 0, runningXAB);
        device.queue.writeBuffer(runningYBuf, 0, runningYAB);
        await device.queue.onSubmittedWorkDone();
      }
      const encoder = device.createCommandEncoder();
      const pass = encoder.beginComputePass();
      pass.setPipeline(pipeline);
      pass.setBindGroup(0, bindGroup);
      pass.dispatchWorkgroups(numWgs, 1, 1);
      pass.end();
      const t0 = performance.now();
      device.queue.submit([encoder.finish()]);
      await device.queue.onSubmittedWorkDone();
      return performance.now() - t0;
    };

    // Warmup; its output is what the correctness check below reads back.
    await dispatch(true);
    log('info', 'warmup dispatch returned');

    let correctness: 'pass' | 'fail' | 'skipped' = 'skipped';
    let correctness_first_fail: PerSizeResult['correctness_first_fail'];

    if (!SKIP_CORRECTNESS) {
      const byteLength = TOTAL_PAIRS * fieldBytes;
      const stagingX = track(
        device.createBuffer({ size: byteLength, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }),
      );
      const stagingY = track(
        device.createBuffer({ size: byteLength, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }),
      );
      const enc = device.createCommandEncoder();
      enc.copyBufferToBuffer(runningXBuf, 0, stagingX, 0, byteLength);
      enc.copyBufferToBuffer(runningYBuf, 0, stagingY, 0, byteLength);
      device.queue.submit([enc.finish()]);
      await Promise.all([stagingX.mapAsync(GPUMapMode.READ), stagingY.mapAsync(GPUMapMode.READ)]);
      // slice(0) detaches the data from the mapping so it survives unmap().
      const gpuX = new Uint32Array(stagingX.getMappedRange().slice(0));
      const gpuY = new Uint32Array(stagingY.getMappedRange().slice(0));
      stagingX.unmap();
      stagingY.unmap();

      // Multiplying by R^-1 takes the GPU output out of Montgomery form.
      const rInv = modInverse(R, p);

      let mismatches = 0;
      const MAX_REPORTED = 4;
      for (let i = 0; i < TOTAL_PAIRS; i++) {
        const gx = (le32u32ToBi(gpuX, i * 8) * rInv) % p;
        const gy = (le32u32ToBi(gpuY, i * 8) * rInv) % p;
        const ex = pairs[i].r.x;
        const ey = pairs[i].r.y;
        if (gx !== ex || gy !== ey) {
          mismatches++;
          if (!correctness_first_fail) {
            correctness_first_fail = {
              i,
              expected_x: ex.toString(),
              got_x: gx.toString(),
              expected_y: ey.toString(),
              got_y: gy.toString(),
            };
          }
          // Stop scanning once enough failures have been seen.
          if (mismatches > MAX_REPORTED) break;
        }
      }
      correctness = mismatches === 0 ? 'pass' : 'fail';
      if (mismatches === 0) {
        log('ok', `correctness: pass (${TOTAL_PAIRS}/${TOTAL_PAIRS} pairs match noble reference)`);
      } else {
        log('err', `correctness: FAIL (${mismatches}+ mismatches; first @ i=${correctness_first_fail!.i})`);
        log('err', `  expected R.x = ${correctness_first_fail!.expected_x.slice(0, 24)}...`);
        log('err', `  got      R.x = ${correctness_first_fail!.got_x.slice(0, 24)}...`);
        log('err', `  expected R.y = ${correctness_first_fail!.expected_y.slice(0, 24)}...`);
        log('err', `  got      R.y = ${correctness_first_fail!.got_y.slice(0, 24)}...`);
      }
    }

    // Timed runs do not reset the running buffers, so each rep adds Q onto the
    // previous result; only the elapsed time is used from these dispatches.
    const samples: number[] = [];
    for (let r = 0; r < reps; r++) {
      samples.push(await dispatch(false));
    }
    const med = median(samples);
    const mn = Math.min(...samples);
    const mx = Math.max(...samples);
    const nsPerPair = (med * 1e6) / TOTAL_PAIRS;

    log(
      correctness === 'fail' ? 'err' : 'ok',
      `B=${batchSize}: median=${med.toFixed(3)}ms min=${mn.toFixed(3)}ms max=${mx.toFixed(3)}ms ns/pair=${nsPerPair.toFixed(1)} correctness=${correctness}`,
    );

    return {
      batch_size: batchSize,
      tpb,
      bs,
      num_wgs: numWgs,
      total_pairs: TOTAL_PAIRS,
      median_ms: med,
      min_ms: mn,
      max_ms: mx,
      ns_per_pair: nsPerPair,
      samples_ms: samples,
      correctness,
      correctness_first_fail,
    };
  } finally {
    for (const b of owned) b.destroy();
  }
}

/**
 * Parses `?reps`, `?total`, `?sizes` and `?skip_correctness=1`, updating the
 * module-level settings. The effective batch sizes (defaults included) are
 * validated against the effective TOTAL_PAIRS so a bad combination fails
 * before the expensive pair generation.
 *
 * @throws Error on any out-of-range or inconsistent value.
 */
function parseParams() {
  const qp = new URLSearchParams(window.location.search);
  const reps = parseInt(qp.get('reps') ?? '5', 10);
  if (!Number.isFinite(reps) || reps <= 0 || reps > 50) {
    throw new Error(`?reps must be in (0, 50], got ${qp.get('reps')}`);
  }
  const totalStr = qp.get('total');
  if (totalStr !== null) {
    const total = parseInt(totalStr, 10);
    if (!Number.isFinite(total) || total <= 0 || total > (1 << 20)) {
      throw new Error(`?total must be in (0, 2^20], got ${totalStr}`);
    }
    TOTAL_PAIRS = total;
  }
  const sizesStr = qp.get('sizes');
  const sizes: readonly number[] = sizesStr !== null ? sizesStr.split(',').map(s => parseInt(s, 10)) : BATCH_SIZES;
  for (const s of sizes) {
    if (!Number.isFinite(s) || s <= 0 || s > 4096) {
      throw new Error(`?sizes entries must be in (0, 4096], got ${s}`);
    }
    if (TOTAL_PAIRS % s !== 0) {
      throw new Error(`?sizes entry ${s} does not divide TOTAL_PAIRS=${TOTAL_PAIRS}`);
    }
    // tpbBsFor also rejects sizes above 64 that are not multiples of 64.
    const { tpb } = tpbBsFor(s);
    if ((tpb & (tpb - 1)) !== 0) {
      throw new Error(`?sizes entry ${s} yields TPB=${tpb} which is not a power of two`);
    }
  }
  BATCH_SIZES = sizes;
  if (qp.get('skip_correctness') === '1') {
    SKIP_CORRECTNESS = true;
  }
  return { reps, total: TOTAL_PAIRS, sizes: BATCH_SIZES, skip_correctness: SKIP_CORRECTNESS };
}

/**
 * Page entry point: parses parameters, generates the reference pairs once and
 * sweeps every batch size, stopping at the first failing size. Failures are
 * logged and stored in `benchState.error` instead of being rethrown.
 */
async function main() {
  try {
    if (!('gpu' in navigator)) {
      throw new Error('navigator.gpu missing — WebGPU not available');
    }
    const params = parseParams();
    benchState.params = params;
    log('info', `params: reps=${params.reps} total=${params.total} sizes=[${params.sizes.join(',')}] skip_correctness=${params.skip_correctness}`);

    benchState.state = 'running';
    const device = await get_device();
    log('info', 'WebGPU device acquired');

    const p = BN254_BASE_FIELD;
    const miscParams = compute_misc_params(p, 13);
    const R = miscParams.r;

    log('info', `generating ${TOTAL_PAIRS} on-curve pairs via noble (this can take a few seconds)…`);
    const t0 = performance.now();
    const pairs = buildPairs(TOTAL_PAIRS, 0xc0ffee);
    log('info', `pair generation done in ${(performance.now() - t0).toFixed(0)} ms`);

    const sm = new ShaderManager(4, TOTAL_PAIRS, BN254_CURVE_CONFIG, false);

    for (const B of BATCH_SIZES) {
      try {
        const r = await runOne(device, sm, B, params.reps, R, p, pairs);
        benchState.results.push(r);
        resultsClient.postProgress({ kind: 'batch_done', batch_size: B, median_ms: r.median_ms, ns_per_pair: r.ns_per_pair, correctness: r.correctness });
      } catch (e) {
        const msg = e instanceof Error ? e.message : String(e);
        log('err', `B=${B} failed: ${msg} — STOPPING sweep at first failure`);
        benchState.state = 'error';
        benchState.error = msg;
        return;
      }
    }

    benchState.state = 'done';
    log('ok', 'all batches done');
  } catch (e) {
    const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e);
    log('err', `FATAL: ${msg}`);
    benchState.state = 'error';
    benchState.error = msg;
  }
}

main()
  .catch(e => {
    const msg = e instanceof Error ? e.message : String(e);
    log('err', `unhandled: ${msg}`);
    benchState.state = 'error';
    benchState.error = msg;
  })
  .finally(() => {
    // Best-effort upload; a failed post must not mask the bench outcome.
    postFinal().catch(() => {});
  });
// diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-chain.html b/barretenberg/ts/dev/msm-webgpu/bench-msm-chain.html new file mode 100644 index 000000000000..423d124585c3 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-chain.html @@ -0,0 +1,37 @@ + + + + + MSM bucket-accumulate pair-tree level-0 bench (marshal+chain) + + + +

MSM bucket-accumulate pair-tree level-0 bench (marshal + chain)

+

Query params: ?reps=R&pairs=N&buckets=B&s=S&wgi=W&disp=D

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-chain.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-chain.ts new file mode 100644 index 000000000000..b644b8ab1598 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-chain.ts @@ -0,0 +1,625 @@ +/// +// bench-msm-chain — Standalone WebGPU bench for the marshal + chain +// pair-tree level-0 pipeline that integrates the ba_rev_packed_carry +// chain kernel into MSM bucket accumulate. +// +// The harness: +// 1. Generates a point pool of N+1 random Montgomery points in the +// SoA-packed layout (2 planes, PG=2 vec4/elem). Index 0 is a +// universal decoy seed; indices 1..N are real points. +// 2. Generates a synthetic CSR: B buckets, each point assigned to a +// uniformly random bucket. csr_indices = points sorted by bucket; +// offset[] = CSR row pointers; count[b] = bucket b's point count. +// 3. Builds a chunk plan from the dense slices: for each bucket b +// with count[b] >= S, splits it into floor(count[b]/S) chunks of +// exactly S points. Total dense chunks T_dense. +// 4. Runs the marshal kernel: T_dense threads, each gathers S point +// coords from the pool into the strided chain layout. +// 5. Runs the recovered ba_rev_packed_carry chain kernel on the +// marshaled layout. +// 6. Times marshal and chain separately (DISP back-to-back dispatches +// per timed sample, amortising submit + drain). +// +// Reported metrics, sweeping S in {16, 32, 64}: +// marshal_ns_per_pt — ns per input point processed by marshal +// chain_ns_per_pt — ns per input point processed by chain +// combined_ns_per_pt — marshal + chain +// density — fraction of N points covered by dense chunks +// +// Out of scope here (covered by follow-on passes): +// - Reduction passes that fold pair-tree levels 1..log2(S)/2 into +// per-bucket totals. These reuse the same chain kernel with +// decreasing T; cost is ~25 ns/pt per level, ~3 levels for S=16. +// - Tail handling for buckets with count < S. 
//     Will reuse the workgroup-scan batch-affine path or a variable-length variant.

import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js';
import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js';
import { makeResultsClient } from './results_post.js';

// vec4 slots per packed field element (8 x u32 words = 2 x vec4<u32>).
const PG = 2;
const DEFAULT_PAIRS = 1 << 17; // 131072
const DEFAULT_BUCKETS = 1 << 13; // 8192 -> avg bucket size = 16
const DEFAULT_WGI = 64;
const DEFAULT_DISP = 8;
const DEFAULT_S_SWEEP: readonly number[] = [16, 32, 64];

// Run-time tunables; they start at the defaults above.
let PAIRS = DEFAULT_PAIRS;
let BUCKETS = DEFAULT_BUCKETS;
let WGI = DEFAULT_WGI;
let DISP = DEFAULT_DISP;
let S_SWEEP: readonly number[] = DEFAULT_S_SWEEP;

/**
 * Deterministic 32-bit linear congruential generator
 * (state' = state * 1664525 + 1013904223 mod 2^32).
 *
 * Only the HIGH-order bits of the output are well distributed: the low k
 * bits of a power-of-two-modulus LCG repeat with period 2^k, so callers
 * must not use `& mask` or `% 2^k` on the raw output.
 *
 * @param seed Any number; coerced to uint32, with 0 replaced by 1.
 * @returns A function yielding the next uint32 state on each call.
 */
function makeRng(seed: number): () => number {
  let state = (seed >>> 0) || 1;
  return () => {
    state = (Math.imul(state, 1664525) + 1013904223) >>> 0;
    return state;
  };
}

/**
 * Rejection-samples an integer v with 0 < v < p.
 *
 * Each attempt assembles ceil(bitlen(p) / 8) bytes, masks the result down
 * to bitlen(p) bits and retries until it lands in range. Bytes are taken
 * from the top 8 bits of each draw: the bottom 8 bits of the LCG above
 * cycle every 256 draws, which would restrict a 32-byte value to a cycle
 * of only 8 candidates.
 *
 * @param p   Exclusive upper bound; must be greater than 1.
 * @param rng Source of uint32 values (see makeRng).
 * @returns A value in the open interval (0, p).
 * @throws RangeError if p <= 1, since no value can satisfy 0 < v < p.
 */
function randomBelow(p: bigint, rng: () => number): bigint {
  if (p <= 1n) {
    throw new RangeError(`randomBelow: modulus must be > 1, got ${p}`);
  }
  const bitlen = p.toString(2).length;
  const byteLen = Math.ceil(bitlen / 8);
  // Loop-invariant: keeps only the low `bitlen` bits of each candidate.
  const mask = (1n << BigInt(bitlen)) - 1n;
  for (;;) {
    let v = 0n;
    // `>>> 24` coerces to uint32 and keeps the most significant byte.
    for (let i = 0; i < byteLen; i++) v = (v << 8n) | BigInt(rng() >>> 24);
    v &= mask;
    if (v > 0n && v < p) return v;
  }
}

/**
 * Splits the low 256 bits of `v` into eight little-endian 32-bit words
 * (word 0 is the least significant). Higher bits are discarded.
 */
function bigintToPackedU32x8(v: bigint): Uint32Array {
  const w = new Uint32Array(8);
  let x = v;
  for (let i = 0; i < 8; i++) {
    w[i] = Number(x & 0xffffffffn);
    x >>= 32n;
  }
  return w;
}

/**
 * Returns the element at index floor(n / 2) of a sorted copy of `xs`.
 * For even n this is the upper of the two middle values (no averaging);
 * for an empty array it is NaN. The input is not mutated.
 */
function median(xs: number[]): number {
  if (xs.length === 0) return NaN;
  const s = xs.slice().sort((a, b) => a - b);
  return s[Math.floor(s.length / 2)];
}

// Pack a pool of M (= N+1) random Montgomery points into the SoA layout
// the marshal kernel expects: 2 planes (P.x, P.y), PG vec4 per element.
// Pool index 0 is the decoy seed.
/**
 * @param poolSize Number of pool entries M (real points plus the decoy at index 0).
 * @param R        Montgomery radix; each sampled coordinate is multiplied by R mod p.
 * @param p        Field modulus.
 * @param rng      uint32 source consumed by randomBelow.
 * @returns u32 buffer of 2 planes x PG vec4 x M elements x 4 words; the vec4
 *          for (plane, v, e) starts at u32 index ((plane * PG + v) * M + e) * 4.
 */
function buildPointPool(poolSize: number, R: bigint, p: bigint, rng: () => number): Uint32Array {
  const M = poolSize;
  const buf = new Uint32Array(2 * PG * M * 4);
  for (let e = 0; e < M; e++) {
    // x is drawn before y so the rng stream order matches per element.
    const x = (randomBelow(p, rng) * R) % p;
    const y = (randomBelow(p, rng) * R) % p;
    const wx = bigintToPackedU32x8(x);
    const wy = bigintToPackedU32x8(y);
    for (let v = 0; v < PG; v++) {
      const baseX = ((0 * PG + v) * M + e) * 4;
      const baseY = ((1 * PG + v) * M + e) * 4;
      // vec4 number v of an element holds words 4v .. 4v+3.
      for (let w = 0; w < 4; w++) {
        buf[baseX + w] = wx[4 * v + w];
        buf[baseY + w] = wy[4 * v + w];
      }
    }
  }
  return buf;
}

// Synthetic CSR: assigns each of the N points to a random bucket in
// [0, B) and groups the points by bucket.
//
// The bucket is taken from the HIGH-order bits of each draw
// (floor(u32 / 2^32 * B)). Reducing the draw with `% B` would use the
// low-order bits, which for a power-of-two-modulus LCG and power-of-two B
// visit every bucket exactly once per B draws, giving every bucket an
// identical count instead of a random occupancy.
//
// Returns:
//   csrIndices - point indices grouped by bucket, ascending within each
//                bucket, values in [1, N+1). Indices are 1-based so pool
//                index 0 stays reserved for the decoy.
//   offsets    - B+1 row pointers; bucket b owns
//                csrIndices[offsets[b] .. offsets[b+1]).
//   counts     - number of points per bucket.
//
// Throws RangeError when N is not a non-negative integer or B is not a
// positive integer.
function buildSyntheticCSR(
  N: number,
  B: number,
  rng: () => number,
): { csrIndices: Uint32Array; offsets: Uint32Array; counts: Uint32Array } {
  if (!Number.isInteger(N) || N < 0) {
    throw new RangeError(`buildSyntheticCSR: N must be a non-negative integer, got ${N}`);
  }
  if (!Number.isInteger(B) || B <= 0) {
    throw new RangeError(`buildSyntheticCSR: B must be a positive integer, got ${B}`);
  }
  const TWO_POW_32 = 4294967296;
  const bucket = new Uint32Array(N);
  const counts = new Uint32Array(B);
  for (let i = 0; i < N; i++) {
    // (u32 / 2^32) lies in [0, 1), so the scaled value lies in [0, B).
    const b = Math.floor(((rng() >>> 0) / TWO_POW_32) * B);
    bucket[i] = b;
    counts[b]++;
  }
  // Exclusive prefix sum of counts -> CSR row pointers.
  const offsets = new Uint32Array(B + 1);
  for (let b = 0; b < B; b++) offsets[b + 1] = offsets[b] + counts[b];
  // Stable counting-sort scatter: cursor[b] is the next free slot in bucket b.
  const cursor = new Uint32Array(B);
  const csrIndices = new Uint32Array(N);
  for (let i = 0; i < N; i++) {
    const b = bucket[i];
    csrIndices[offsets[b] + cursor[b]++] = i + 1; // 1-based: 0 is decoy
  }
  return { csrIndices, offsets, counts };
}

// Walk the CSR row pointers and produce the dense chunk plan. For each
// bucket b with count[b] >= S, emit floor(count[b]/S) chunks each
// pointing at S consecutive csr_indices entries.
tail = sum of leftover +// points (count[b] mod S) that this v1 bench skips. +function buildChunkPlan( + offsets: Uint32Array, + counts: Uint32Array, + S: number, +): { chunkPlan: Uint32Array; T: number; tailPoints: number } { + const B = counts.length; + let T = 0; + let tail = 0; + for (let b = 0; b < B; b++) { + const c = counts[b]; + T += Math.floor(c / S); + tail += c % S; + } + const chunkPlan = new Uint32Array(2 * T); + let t = 0; + for (let b = 0; b < B; b++) { + const c = counts[b]; + const nChunks = Math.floor(c / S); + for (let k = 0; k < nChunks; k++) { + chunkPlan[2 * t + 0] = b; + chunkPlan[2 * t + 1] = offsets[b] + k * S; + t++; + } + } + return { chunkPlan, T, tailPoints: tail }; +} + +interface PerSizeResult { + s: number; + wgi: number; + pairs: number; + buckets: number; + T: number; + tail_points: number; + density: number; + disp: number; + marshal_median_ms: number; + marshal_min_ms: number; + marshal_max_ms: number; + marshal_ns_per_pt: number; + chain_median_ms: number; + chain_min_ms: number; + chain_max_ms: number; + chain_ns_per_pt: number; + combined_ns_per_pt: number; + marshal_samples_ms: number[]; + chain_samples_ms: number[]; + sanity_ok: boolean; +} + +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: { reps: number; pairs: number; buckets: number; wgi: number; disp: number; s_sweep: readonly number[] } | null; + results: PerSizeResult[]; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { + state: 'boot', + params: null, + results: [], + error: null, + log: [], +}; +(window as unknown as { __bench: BenchState }).__bench = benchState; + +const resultsClient = makeResultsClient({ page: 'bench-msm-chain' }); +(window as unknown as { __runId: string }).__runId = resultsClient.runId; + +async function postFinal(): Promise { + await resultsClient.postResults({ + state: benchState.state, + params: benchState.params, + results: benchState.results, + error: benchState.error, + log: 
benchState.log, + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); +} + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-msm-chain] ${msg}`); +} + +interface PipelineInfo { + pipeline: GPUComputePipeline; + layout: GPUBindGroupLayout; +} + +async function compile( + device: GPUDevice, + code: string, + cacheKey: string, + bindLayout: GPUBindGroupLayout, +): Promise { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + const errLines: string[] = []; + for (const msg of info.messages) { + const line = `[shader ${cacheKey}] ${msg.type}: ${msg.message} (line ${msg.lineNum}, col ${msg.linePos})`; + if (msg.type === 'error') { + console.error(line); + log('err', line); + errLines.push(line); + hasError = true; + } else { + console.warn(line); + log('warn', line); + } + } + if (hasError) { + throw new Error(`WGSL compile failed for ${cacheKey}: ${errLines.join(' | ')}`); + } + return device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [bindLayout] }), + compute: { module, entryPoint: 'main' }, + }); +} + +function chainLayout(device: GPUDevice): GPUBindGroupLayout { + return device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 
'uniform' } }, + ], + }); +} + +function marshalLayout(device: GPUDevice): GPUBindGroupLayout { + return device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); +} + +async function timeDispatchPasses( + device: GPUDevice, + pipeline: GPUComputePipeline, + bind: GPUBindGroup, + numWgs: number, + reps: number, + passes: number, +): Promise { + // warmup + { + const enc = device.createCommandEncoder(); + for (let pIdx = 0; pIdx < passes; pIdx++) { + const pass = enc.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroups(numWgs, 1, 1); + pass.end(); + } + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + } + const samples: number[] = []; + for (let r = 0; r < reps; r++) { + const enc = device.createCommandEncoder(); + for (let pIdx = 0; pIdx < passes; pIdx++) { + const pass = enc.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroups(numWgs, 1, 1); + pass.end(); + } + const t0 = performance.now(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + samples.push(performance.now() - t0); + } + return samples; +} + +async function readNonZero(device: GPUDevice, buf: GPUBuffer, u32Count: number): Promise { + const bytes = u32Count * 4; + const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }); + const enc = device.createCommandEncoder(); + enc.copyBufferToBuffer(buf, 0, staging, 0, bytes); + 
device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + await staging.mapAsync(GPUMapMode.READ); + const u32 = new Uint32Array(staging.getMappedRange().slice(0)); + staging.unmap(); + staging.destroy(); + for (let i = 0; i < u32.length; i++) if (u32[i] !== 0) return true; + return false; +} + +async function runOne( + device: GPUDevice, + sm: ShaderManager, + s: number, + reps: number, + R: bigint, + p: bigint, + seed: number, +): Promise { + log('info', `=== S=${s}: PAIRS=${PAIRS} BUCKETS=${BUCKETS} WGI=${WGI} DISP=${DISP}`); + + const rng = makeRng(seed); + const poolSize = PAIRS + 1; // index 0 reserved as decoy + const poolU32 = buildPointPool(poolSize, R, p, rng); + const { csrIndices, counts } = buildSyntheticCSR(PAIRS, BUCKETS, rng); + // Recompute offsets locally (buildSyntheticCSR returns them too, but + // we only need it inside buildChunkPlan). + const offsets = new Uint32Array(BUCKETS + 1); + for (let b = 0; b < BUCKETS; b++) offsets[b + 1] = offsets[b] + counts[b]; + + const { chunkPlan, T, tailPoints } = buildChunkPlan(offsets, counts, s); + const density = (T * s) / PAIRS; + const numWgs = Math.ceil(T / WGI); + log( + 'info', + `chunk plan: T=${T} chunks (S=${s} each) -> ${T * s} pts (${(density * 100).toFixed(1)}% of ${PAIRS}); tail=${tailPoints} pts skipped`, + ); + + if (T === 0) { + throw new Error(`S=${s}: no dense chunks (every bucket has count<${s}). Try smaller S or fewer buckets.`); + } + + // GPU buffers. 
+ const mkSb = (size: number, copyDst: boolean, copySrc: boolean): GPUBuffer => { + let usage = GPUBufferUsage.STORAGE; + if (copyDst) usage |= GPUBufferUsage.COPY_DST; + if (copySrc) usage |= GPUBufferUsage.COPY_SRC; + return device.createBuffer({ size, usage }); + }; + + const poolBuf = mkSb(poolU32.byteLength, true, false); + const csrBuf = mkSb(csrIndices.byteLength, true, false); + const chunkBuf = mkSb(chunkPlan.byteLength, true, false); + const chainBytes = 4 * PG * (T * s) * 4 * 4; // 4 planes * PG vec4/elem * (T*S) elems * 4 u32/vec4 * 4 B/u32 + const chainBuf = mkSb(chainBytes, false, true); + const marshalParams = device.createBuffer({ + size: 16, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + const chainParams = device.createBuffer({ + size: 16, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + + device.queue.writeBuffer(poolBuf, 0, poolU32); + device.queue.writeBuffer(csrBuf, 0, csrIndices); + device.queue.writeBuffer(chunkBuf, 0, chunkPlan); + // marshal: params = [T, poolSize, _, _] + device.queue.writeBuffer(marshalParams, 0, new Uint32Array([T, poolSize, 0, 0])); + // chain: params = [N_chain, T, _, _] with N_chain = T*S + device.queue.writeBuffer(chainParams, 0, new Uint32Array([T * s, T, 0, 0])); + + // Compile marshal pipeline. + const marshalCode = sm.gen_ba_marshal_chain_shader(WGI, s); + log('info', `marshal shader ${marshalCode.length} chars`); + const mLayout = marshalLayout(device); + const marshalPipeline = await compile(device, marshalCode, `marshal-W${WGI}-S${s}`, mLayout); + const marshalBind = device.createBindGroup({ + layout: mLayout, + entries: [ + { binding: 0, resource: { buffer: csrBuf } }, + { binding: 1, resource: { buffer: chunkBuf } }, + { binding: 2, resource: { buffer: poolBuf } }, + { binding: 3, resource: { buffer: chainBuf } }, + { binding: 4, resource: { buffer: marshalParams } }, + ], + }); + + // Compile chain pipeline. 
+ const chainCode = sm.gen_ba_rev_packed_carry_bench_shader(WGI, s); + log('info', `chain shader ${chainCode.length} chars`); + const cLayout = chainLayout(device); + const chainPipeline = await compile(device, chainCode, `chain-W${WGI}-S${s}`, cLayout); + const dummy = device.createBuffer({ size: 16, usage: GPUBufferUsage.STORAGE }); + // Chain output buffer (separate from input chain_buf since the kernel + // writes its R outputs to a 2-plane output buffer). + const chainOutBytes = 2 * PG * (T * s) * 4 * 4; + const chainOutBuf = mkSb(chainOutBytes, false, true); + const chainBind = device.createBindGroup({ + layout: cLayout, + entries: [ + { binding: 0, resource: { buffer: chainBuf } }, + { binding: 1, resource: { buffer: dummy } }, + { binding: 2, resource: { buffer: chainOutBuf } }, + { binding: 3, resource: { buffer: chainParams } }, + ], + }); + + log('info', `marshal: numWgs=${numWgs}, ${T} threads, ${T * s} pts gathered/dispatch`); + log('info', `chain : numWgs=${numWgs}, ${T} threads, S=${s} adds/thread = ${T * s} pts processed/dispatch`); + + // Marshal must run at least once before chain (chain reads chain_buf). + // Warmup is built into timeDispatchPasses. + const marshalSamples = await timeDispatchPasses(device, marshalPipeline, marshalBind, numWgs, reps, DISP); + const chainSamples = await timeDispatchPasses(device, chainPipeline, chainBind, numWgs, reps, DISP); + + const sanityOk = await readNonZero(device, chainOutBuf, 8); + + const marshalMed = median(marshalSamples); + const chainMed = median(chainSamples); + const ptsPerSample = T * s * DISP; + const marshalNsPerPt = (marshalMed * 1e6) / ptsPerSample; + const chainNsPerPt = (chainMed * 1e6) / ptsPerSample; + const combinedNsPerPt = marshalNsPerPt + chainNsPerPt; + + log( + sanityOk ? 'ok' : 'err', + `S=${s}: marshal=${marshalNsPerPt.toFixed(2)}ns/pt chain=${chainNsPerPt.toFixed(2)}ns/pt combined=${combinedNsPerPt.toFixed(2)}ns/pt density=${(density * 100).toFixed(1)}% sanity=${sanityOk ? 
'OK' : 'FAIL'}`, + ); + + poolBuf.destroy(); + csrBuf.destroy(); + chunkBuf.destroy(); + chainBuf.destroy(); + chainOutBuf.destroy(); + dummy.destroy(); + marshalParams.destroy(); + chainParams.destroy(); + + return { + s, + wgi: WGI, + pairs: PAIRS, + buckets: BUCKETS, + T, + tail_points: tailPoints, + density, + disp: DISP, + marshal_median_ms: marshalMed, + marshal_min_ms: Math.min(...marshalSamples), + marshal_max_ms: Math.max(...marshalSamples), + marshal_ns_per_pt: marshalNsPerPt, + chain_median_ms: chainMed, + chain_min_ms: Math.min(...chainSamples), + chain_max_ms: Math.max(...chainSamples), + chain_ns_per_pt: chainNsPerPt, + combined_ns_per_pt: combinedNsPerPt, + marshal_samples_ms: marshalSamples, + chain_samples_ms: chainSamples, + sanity_ok: sanityOk, + }; +} + +function parseParams() { + const qp = new URLSearchParams(window.location.search); + const reps = parseInt(qp.get('reps') ?? '5', 10); + if (!Number.isFinite(reps) || reps <= 0 || reps > 50) { + throw new Error(`?reps must be in (0, 50], got ${qp.get('reps')}`); + } + const pairsStr = qp.get('pairs'); + if (pairsStr !== null) { + const v = parseInt(pairsStr, 10); + if (!Number.isFinite(v) || v <= 0 || v > (1 << 20)) { + throw new Error(`?pairs must be in (0, 2^20], got ${pairsStr}`); + } + PAIRS = v; + } + const bucketsStr = qp.get('buckets'); + if (bucketsStr !== null) { + const v = parseInt(bucketsStr, 10); + if (!Number.isFinite(v) || v <= 0 || v > (1 << 18)) { + throw new Error(`?buckets must be in (0, 2^18], got ${bucketsStr}`); + } + BUCKETS = v; + } + const wgiStr = qp.get('wgi'); + if (wgiStr !== null) { + const v = parseInt(wgiStr, 10); + if (!Number.isFinite(v) || v <= 0 || v > 1024) { + throw new Error(`?wgi must be in (0, 1024], got ${wgiStr}`); + } + WGI = v; + } + const dispStr = qp.get('disp'); + if (dispStr !== null) { + const v = parseInt(dispStr, 10); + if (!Number.isFinite(v) || v <= 0 || v > 64) { + throw new Error(`?disp must be in (0, 64], got ${dispStr}`); + } + DISP = v; 
+ } + const sStr = qp.get('s'); + if (sStr !== null) { + const list = sStr.split(',').map(v => parseInt(v, 10)); + for (const v of list) { + if (!Number.isFinite(v) || v <= 0 || v > 256) { + throw new Error(`?s entries must be in (0, 256], got ${v}`); + } + } + S_SWEEP = list; + } + return { reps, pairs: PAIRS, buckets: BUCKETS, wgi: WGI, disp: DISP, s_sweep: S_SWEEP }; +} + +async function main() { + try { + if (!('gpu' in navigator)) { + throw new Error('navigator.gpu missing — WebGPU not available'); + } + const params = parseParams(); + benchState.params = params; + log( + 'info', + `params: reps=${params.reps} pairs=${params.pairs} buckets=${params.buckets} wgi=${params.wgi} disp=${params.disp} s=[${params.s_sweep.join(',')}]`, + ); + + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + + const p = BN254_BASE_FIELD; + const miscParams = compute_misc_params(p, 13); + const R = miscParams.r; + + const sm = new ShaderManager(4, PAIRS, BN254_CURVE_CONFIG, false); + + let seed = 0xc511; + for (const s of S_SWEEP) { + try { + const r = await runOne(device, sm, s, params.reps, R, p, seed); + benchState.results.push(r); + resultsClient.postProgress({ + kind: 'batch_done', + s, + marshal_ns_per_pt: r.marshal_ns_per_pt, + chain_ns_per_pt: r.chain_ns_per_pt, + combined_ns_per_pt: r.combined_ns_per_pt, + density: r.density, + sanity_ok: r.sanity_ok, + }); + seed += 0x10; + } catch (e) { + const msg = e instanceof Error ? e.message : String(e); + log('err', `S=${s} failed: ${msg} — STOPPING sweep at first failure`); + benchState.state = 'error'; + benchState.error = msg; + return; + } + } + + benchState.state = 'done'; + log('ok', 'all sizes done'); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main() + .catch(e => { + const msg = e instanceof Error ? 
e.message : String(e); + log('err', `unhandled: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + }) + .finally(() => { + postFinal().catch(() => {}); + }); diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle-prod.html b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle-prod.html new file mode 100644 index 000000000000..5126b5f73cc0 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle-prod.html @@ -0,0 +1,22 @@ + + + + + v2 prod orchestrator + indirect dispatch noble oracle (WebGPU) + + + +

v2 prod orchestrator + indirect dispatch noble oracle

+

Query params: ?subtasks=T&columns=B&input=N&s=S&wgi=W&tpb=TPB&per=PER&lvls=L&seed=K

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle-prod.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle-prod.ts new file mode 100644 index 000000000000..10e2be473956 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle-prod.ts @@ -0,0 +1,349 @@ +/// +// End-to-end correctness oracle for the prod v2 pair-tree orchestrator. +// +// Wraps runSmvpV2PairTree (cuzk/smvp_v2_pair_tree.ts) with a noble-CPU +// cross-check on real BN254 affine points. Validates the full prod +// path: csr_to_v2_meta + csr_to_v2_active_sums + planner_v2_prod + +// marshal_prod + disjoint_prod + scatter_prod + carry_prod + +// v2_to_running, with indirect dispatch driven by the planner's +// per-level totals. +// +// Sizing: small by design (num_subtasks=1, num_columns=32, N=256) so +// noble's projective bucket sum runs instantly in the browser. Each +// bucket gets ~8 points; pair-tree needs ~4 levels. + +import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js'; +import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js'; +import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js'; +import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js'; +import { BN254_BASE_FIELD, modInverse } from '../../src/msm_webgpu/cuzk/bn254.js'; +import { runSmvpV2PairTree } from '../../src/msm_webgpu/cuzk/smvp_v2_pair_tree.js'; +import { makeResultsClient } from './results_post.js'; +import { bn254 } from '@noble/curves/bn254'; + +let NUM_SUBTASKS = 1; +let NUM_COLUMNS = 32; +let INPUT_SIZE = 256; +let S = 16; +let WGI = 64; +let TPB = 64; +let PER_THREAD = 1; +let MAX_LEVELS = 8; +let SEED = 0xc0de; + +function makeRng(seed: number): () => number { + let state = (seed >>> 0) || 1; + return () => { + state = (Math.imul(state, 1664525) + 1013904223) >>> 0; + return state; + }; +} + +function bigintToPackedU32x8(v: bigint): Uint32Array { + const w = new Uint32Array(8); + let x = v; + for (let i = 0; i < 8; i++) { + w[i] = 
Number(x & 0xffffffffn); + x >>= 32n; + } + return w; +} + +function packedU32x8ToBigint(w: Uint32Array, off: number): bigint { + let v = 0n; + for (let i = 7; i >= 0; i--) v = (v << 32n) | BigInt(w[off + i] >>> 0); + return v; +} + +interface OracleResult { + num_subtasks: number; + num_columns: number; + input_size: number; + s: number; + wgi: number; + tpb: number; + per_thread: number; + max_levels: number; + total_passes: number; + gpu_wall_ms: number; + buckets_checked: number; + buckets_passed: number; + first_mismatches: Array<{ subtask: number; bucket: number; count: number; gpu_x: string; gpu_y: string; ref_x: string; ref_y: string; ok: boolean }>; + all_passed: boolean; +} + +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: Record | null; + results: OracleResult[]; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { state: 'boot', params: null, results: [], error: null, log: [] }; +(window as unknown as { __bench: BenchState }).__bench = benchState; +const resultsClient = makeResultsClient({ page: 'bench-msm-oracle-prod' }); +(window as unknown as { __runId: string }).__runId = resultsClient.runId; + +async function postFinal(): Promise { + await resultsClient.postResults({ + state: benchState.state, + params: benchState.params, + results: benchState.results, + error: benchState.error, + log: benchState.log, + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); +} + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 
'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-msm-oracle-prod] ${msg}`); +} + +async function readbackU32(device: GPUDevice, buf: GPUBuffer, bytes: number): Promise { + const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }); + const enc = device.createCommandEncoder(); + enc.copyBufferToBuffer(buf, 0, staging, 0, bytes); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + await staging.mapAsync(GPUMapMode.READ); + const out = new Uint32Array(staging.getMappedRange().slice(0)); + staging.unmap(); + staging.destroy(); + return out; +} + +async function runOracle(device: GPUDevice, sm: ShaderManager, R: bigint, Rinv: bigint, p: bigint): Promise { + log('info', `=== T=${NUM_SUBTASKS} B=${NUM_COLUMNS} N=${INPUT_SIZE} S=${S} WGI=${WGI} TPB=${TPB} PER=${PER_THREAD} MAX_LEVELS=${MAX_LEVELS}`); + const rng = makeRng(SEED); + const G1 = bn254.G1.ProjectivePoint; + const order = bn254.fields.Fr.ORDER; + + const points: Array<{ x: bigint; y: bigint }> = []; + const pointXWords = new Uint32Array(INPUT_SIZE * 8); + const pointYWords = new Uint32Array(INPUT_SIZE * 8); + for (let i = 0; i < INPUT_SIZE; i++) { + let k = 0n; + for (let w = 0; w < 8; w++) k = (k << 32n) | BigInt(rng() >>> 0); + k = k % order; + if (k === 0n) k = 1n; + const aff = G1.BASE.multiply(k).toAffine(); + points.push({ x: aff.x, y: aff.y }); + pointXWords.set(bigintToPackedU32x8((aff.x * R) % p), 8 * i); + pointYWords.set(bigintToPackedU32x8((aff.y * R) % p), 8 * i); + } + log('info', `generated ${INPUT_SIZE} BN254 affine points`); + + const valIdxArr = new Uint32Array(NUM_SUBTASKS * INPUT_SIZE); + const rowPtrArr = new Uint32Array(NUM_SUBTASKS * (NUM_COLUMNS + 1)); + const bucketOf: Uint32Array[] = []; + for (let st = 0; st < NUM_SUBTASKS; st++) { + const bucket = new 
Uint32Array(INPUT_SIZE); + const counts = new Uint32Array(NUM_COLUMNS); + for (let i = 0; i < INPUT_SIZE; i++) { + const hi = (rng() >>> 16) & 0xffff; + const lo = (rng() >>> 16) & 0xffff; + const v = hi * 0x10000 + lo; + const b = v % NUM_COLUMNS; + bucket[i] = b; + counts[b]++; + } + bucketOf.push(bucket); + const offsets = new Uint32Array(NUM_COLUMNS + 1); + for (let b = 0; b < NUM_COLUMNS; b++) offsets[b + 1] = offsets[b] + counts[b]; + const cursor = new Uint32Array(NUM_COLUMNS); + for (let i = 0; i < INPUT_SIZE; i++) { + const b = bucket[i]; + const slot = offsets[b] + cursor[b]++; + valIdxArr[st * INPUT_SIZE + slot] = i; + } + const rpBase = st * (NUM_COLUMNS + 1); + for (let b = 0; b <= NUM_COLUMNS; b++) rowPtrArr[rpBase + b] = offsets[b]; + } + log('info', `built synthetic CSR for ${NUM_SUBTASKS} window(s)`); + + const mk = (bytes: number, extra: GPUBufferUsageFlags = 0): GPUBuffer => + device.createBuffer({ size: bytes, usage: GPUBufferUsage.STORAGE | extra }); + const val_idx_buf = mk(valIdxArr.byteLength, GPUBufferUsage.COPY_DST); + const row_ptr_buf = mk(rowPtrArr.byteLength, GPUBufferUsage.COPY_DST); + const point_x_buf = mk(INPUT_SIZE * 32, GPUBufferUsage.COPY_DST); + const point_y_buf = mk(INPUT_SIZE * 32, GPUBufferUsage.COPY_DST); + const running_x_buf = mk(NUM_SUBTASKS * NUM_COLUMNS * 32, GPUBufferUsage.COPY_SRC); + const running_y_buf = mk(NUM_SUBTASKS * NUM_COLUMNS * 32, GPUBufferUsage.COPY_SRC); + const bucket_active_buf = mk(NUM_SUBTASKS * NUM_COLUMNS * 4, GPUBufferUsage.COPY_SRC); + + device.queue.writeBuffer(val_idx_buf, 0, valIdxArr as BufferSource); + device.queue.writeBuffer(row_ptr_buf, 0, rowPtrArr as BufferSource); + device.queue.writeBuffer(point_x_buf, 0, pointXWords as BufferSource); + device.queue.writeBuffer(point_y_buf, 0, pointYWords as BufferSource); + + const stats = await runSmvpV2PairTree({ + device, + shaderManager: sm, + num_subtasks: NUM_SUBTASKS, + num_columns: NUM_COLUMNS, + input_size: INPUT_SIZE, + s: S, + tpb: TPB, + 
per_thread: PER_THREAD, + wgi: WGI, + max_levels: MAX_LEVELS, + val_idx_buf, row_ptr_buf, point_x_buf, point_y_buf, + running_x_buf, running_y_buf, bucket_active_buf, + }); + log('info', `runSmvpV2PairTree returned: ${JSON.stringify(stats)}`); + + const runningXWords = await readbackU32(device, running_x_buf, NUM_SUBTASKS * NUM_COLUMNS * 32); + const runningYWords = await readbackU32(device, running_y_buf, NUM_SUBTASKS * NUM_COLUMNS * 32); + const bucketActive = await readbackU32(device, bucket_active_buf, NUM_SUBTASKS * NUM_COLUMNS * 4); + + const refSumPerBucket = new Map(); + for (let st = 0; st < NUM_SUBTASKS; st++) { + const bucket = bucketOf[st]; + for (let b = 0; b < NUM_COLUMNS; b++) { + let acc = G1.ZERO; + let count = 0; + for (let i = 0; i < INPUT_SIZE; i++) { + if (bucket[i] !== b) continue; + acc = acc.add(G1.fromAffine({ x: points[i].x, y: points[i].y })); + count++; + } + const bucket_global = st * NUM_COLUMNS + b; + refSumPerBucket.set(bucket_global, count === 0 ? null : (acc.is0() ? null : acc.toAffine())); + } + } + + const checks: OracleResult['first_mismatches'] = []; + const mismatches: OracleResult['first_mismatches'] = []; + let passCount = 0; + for (let st = 0; st < NUM_SUBTASKS; st++) { + const bucket = bucketOf[st]; + const counts = new Uint32Array(NUM_COLUMNS); + for (let i = 0; i < INPUT_SIZE; i++) counts[bucket[i]]++; + for (let b = 0; b < NUM_COLUMNS; b++) { + const bucket_global = st * NUM_COLUMNS + b; + const ref = refSumPerBucket.get(bucket_global) ?? null; + const active = bucketActive[bucket_global]; + if (counts[b] === 0) { + if (active !== 0) mismatches.push({ subtask: st, bucket: b, count: 0, gpu_x: 'active=1', gpu_y: '', ref_x: 'empty', ref_y: '', ok: false }); + continue; + } + if (active === 0) { + mismatches.push({ subtask: st, bucket: b, count: counts[b], gpu_x: 'active=0', gpu_y: '', ref_x: ref ? ref.x.toString(16) : 'INF', ref_y: ref ? 
ref.y.toString(16) : 'INF', ok: false }); + continue; + } + const xMont = packedU32x8ToBigint(runningXWords, bucket_global * 8); + const yMont = packedU32x8ToBigint(runningYWords, bucket_global * 8); + const gx = (xMont * Rinv) % p; + const gy = (yMont * Rinv) % p; + const ok = ref !== null && gx === ref.x && gy === ref.y; + const entry = { subtask: st, bucket: b, count: counts[b], gpu_x: gx.toString(16), gpu_y: gy.toString(16), ref_x: ref ? ref.x.toString(16) : 'INF', ref_y: ref ? ref.y.toString(16) : 'INF', ok }; + checks.push(entry); + if (ok) passCount++; + else if (mismatches.length < 8) mismatches.push(entry); + } + } + const allPassed = mismatches.length === 0; + + if (allPassed) { + log('ok', `oracle PASS — ${passCount}/${checks.length} buckets match noble reference (prod orchestrator)`); + } else { + log('err', `oracle FAIL — ${mismatches.length} mismatches in first ${checks.length} buckets`); + for (const m of mismatches.slice(0, 8)) { + log('err', ` subtask=${m.subtask} bucket=${m.bucket} count=${m.count}`); + log('err', ` gpu: x=${m.gpu_x} y=${m.gpu_y}`); + log('err', ` ref: x=${m.ref_x} y=${m.ref_y}`); + } + } + + val_idx_buf.destroy(); + row_ptr_buf.destroy(); + point_x_buf.destroy(); + point_y_buf.destroy(); + running_x_buf.destroy(); + running_y_buf.destroy(); + bucket_active_buf.destroy(); + + return { + num_subtasks: NUM_SUBTASKS, + num_columns: NUM_COLUMNS, + input_size: INPUT_SIZE, + s: S, + wgi: WGI, + tpb: TPB, + per_thread: PER_THREAD, + max_levels: MAX_LEVELS, + total_passes: stats.total_passes, + gpu_wall_ms: stats.gpu_wall_ms, + buckets_checked: checks.length, + buckets_passed: passCount, + first_mismatches: mismatches, + all_passed: allPassed, + }; +} + +function parseParams() { + const qp = new URLSearchParams(window.location.search); + if (qp.get('subtasks')) NUM_SUBTASKS = parseInt(qp.get('subtasks')!, 10); + if (qp.get('columns')) NUM_COLUMNS = parseInt(qp.get('columns')!, 10); + if (qp.get('input')) INPUT_SIZE = 
parseInt(qp.get('input')!, 10); + if (qp.get('s')) S = parseInt(qp.get('s')!, 10); + if (qp.get('wgi')) WGI = parseInt(qp.get('wgi')!, 10); + if (qp.get('tpb')) TPB = parseInt(qp.get('tpb')!, 10); + if (qp.get('per')) PER_THREAD = parseInt(qp.get('per')!, 10); + if (qp.get('lvls')) MAX_LEVELS = parseInt(qp.get('lvls')!, 10); + if (qp.get('seed')) SEED = parseInt(qp.get('seed')!, 10); + return { subtasks: NUM_SUBTASKS, columns: NUM_COLUMNS, input: INPUT_SIZE, s: S, wgi: WGI, tpb: TPB, per_thread: PER_THREAD, max_levels: MAX_LEVELS, seed: SEED }; +} + +async function main() { + try { + if (!('gpu' in navigator)) throw new Error('navigator.gpu missing'); + const params = parseParams(); + benchState.params = params; + log('info', `params: ${JSON.stringify(params)}`); + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + const p = BN254_BASE_FIELD; + const miscParams = compute_misc_params(p, 13); + const R = miscParams.r; + const Rinv = modInverse(R, p); + const sm = new ShaderManager(NUM_SUBTASKS, NUM_COLUMNS, BN254_CURVE_CONFIG, false); + const r = await runOracle(device, sm, R, Rinv, p); + benchState.results.push(r); + resultsClient.postProgress({ + kind: 'oracle_prod_done', + all_passed: r.all_passed, + buckets_passed: r.buckets_passed, + buckets_checked: r.buckets_checked, + gpu_wall_ms: r.gpu_wall_ms, + }); + benchState.state = 'done'; + log('ok', 'done'); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main() + .catch(e => { + const msg = e instanceof Error ? 
e.message : String(e); + log('err', `unhandled: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + }) + .finally(() => { + postFinal().catch(() => {}); + }); diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.html b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.html new file mode 100644 index 000000000000..8296a4e0844b --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.html @@ -0,0 +1,22 @@ + + + + + v2 pair-tree bucket-accumulate noble-CPU oracle (WebGPU) + + + +

v2 pair-tree bucket-accumulate noble-CPU oracle

+

Query params: ?n=N&buckets=B&s=S&wgi=W&seed=K

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.ts new file mode 100644 index 000000000000..2addcec51fae --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-oracle.ts @@ -0,0 +1,625 @@ +/// +// End-to-end correctness oracle for the v2 bin-packed pair-tree +// bucket-accumulate pipeline. Feeds REAL BN254 affine points (random +// scalar * G, via @noble/curves) into the pair-tree and verifies that +// the per-bucket reduced sum matches a noble-projective reference. +// +// This is the test fused_revcarry never had: a ground-truth oracle on +// real curve data. If this passes, the v2 pair-tree's bucket-accumulate +// math (disjoint pair-sum + suffix-product single-fr_inv + lean affine +// add) is correct end-to-end on real BN254 points. +// +// Scope: validates the round kernel + planner only. BPR / horner / +// finalize are NOT part of v2 yet — they're step 3+ of the rewrite plan. +// This oracle stops at "per-bucket sum is correct". +// +// Sizing: tiny by design — N=256 points, B=32 buckets, single window +// (no signed slicing). Each bucket gets ~8 points; the pair-tree +// reduces in ~4 levels. 
+ +import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js'; +import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js'; +import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js'; +import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js'; +import { BN254_BASE_FIELD, modInverse } from '../../src/msm_webgpu/cuzk/bn254.js'; +import { makeResultsClient } from './results_post.js'; +import { bn254 } from '@noble/curves/bn254'; + +const PG = 2; +let NPTS = 256; +let BUCKETS = 32; +let S = 16; +let WGI = 64; +let SEED = 0xa110ce; + +function makeRng(seed: number): () => number { + let state = (seed >>> 0) || 1; + return () => { + state = (Math.imul(state, 1664525) + 1013904223) >>> 0; + return state; + }; +} + +function bigintToPackedU32x8(v: bigint): Uint32Array { + const w = new Uint32Array(8); + let x = v; + for (let i = 0; i < 8; i++) { + w[i] = Number(x & 0xffffffffn); + x >>= 32n; + } + return w; +} + +function packedU32x8ToBigint(w: Uint32Array, off: number): bigint { + let v = 0n; + for (let i = 7; i >= 0; i--) v = (v << 32n) | BigInt(w[off + i] >>> 0); + return v; +} + +function makeSoABuf(device: GPUDevice, M: number, copyDst: boolean, copySrc: boolean): GPUBuffer { + const bytes = 2 * PG * M * 4 * 4; + let usage = GPUBufferUsage.STORAGE; + if (copyDst) usage |= GPUBufferUsage.COPY_DST; + if (copySrc) usage |= GPUBufferUsage.COPY_SRC; + return device.createBuffer({ size: bytes, usage }); +} + +interface CurvePoint { + x: bigint; + y: bigint; +} + +function buildL0WithRealPoints( + N: number, + B: number, + R: bigint, + p: bigint, + rng: () => number, +): { + initBuf: Uint32Array; + initCounts: Uint32Array; + initOffsets: Uint32Array; + M: number; + points: CurvePoint[]; + bucket: Uint32Array; +} { + const M = N + 2; + const buf = new Uint32Array(2 * PG * M * 4); + const G1 = bn254.G1.ProjectivePoint; + const order = bn254.fields.Fr.ORDER; + + const points: CurvePoint[] = []; + const xWords = new 
Uint32Array(8 * M); + const yWords = new Uint32Array(8 * M); + + for (let i = 0; i < N; i++) { + let k = 0n; + for (let w = 0; w < 8; w++) k = (k << 32n) | BigInt(rng() >>> 0); + k = k % order; + if (k === 0n) k = 1n; + const aff = G1.BASE.multiply(k).toAffine(); + points.push({ x: aff.x, y: aff.y }); + const xMont = (aff.x * R) % p; + const yMont = (aff.y * R) % p; + xWords.set(bigintToPackedU32x8(xMont), 8 * i); + yWords.set(bigintToPackedU32x8(yMont), 8 * i); + } + + for (let pad = 0; pad < 2; pad++) { + const i = N + pad; + let xCand: bigint; + do { + xCand = 0n; + for (let w = 0; w < 8; w++) xCand = (xCand << 32n) | BigInt(rng() >>> 0); + xCand = xCand % p; + } while (xCand === 0n); + const yCand = ((xCand + 1n + BigInt(pad)) * 7n) % p; + const xMont = (xCand * R) % p; + const yMont = (yCand * R) % p; + xWords.set(bigintToPackedU32x8(xMont), 8 * i); + yWords.set(bigintToPackedU32x8(yMont), 8 * i); + } + + const bucket = new Uint32Array(N); + const counts = new Uint32Array(B); + for (let i = 0; i < N; i++) { + const hi = (rng() >>> 16) & 0xffff; + const lo = (rng() >>> 16) & 0xffff; + const v = hi * 0x10000 + lo; + const b = v % B; + bucket[i] = b; + counts[b]++; + } + const offsets = new Uint32Array(B + 1); + for (let b = 0; b < B; b++) offsets[b + 1] = offsets[b] + counts[b]; + + const cursor = new Uint32Array(B); + const writeElem = (planeIdx: number, dstIdx: number, words: Uint32Array, srcOff: number) => { + const planeBase = planeIdx * PG * M; + for (let v = 0; v < PG; v++) { + const base = (planeBase + PG * dstIdx + v) * 4; + buf[base + 0] = words[srcOff + 4 * v + 0]; + buf[base + 1] = words[srcOff + 4 * v + 1]; + buf[base + 2] = words[srcOff + 4 * v + 2]; + buf[base + 3] = words[srcOff + 4 * v + 3]; + } + }; + for (let i = 0; i < N; i++) { + const b = bucket[i]; + const dst = offsets[b] + cursor[b]++; + writeElem(0, dst, xWords, 8 * i); + writeElem(1, dst, yWords, 8 * i); + } + writeElem(0, M - 2, xWords, 8 * (M - 2)); + writeElem(1, M - 2, yWords, 8 * 
(M - 2)); + writeElem(0, M - 1, xWords, 8 * (M - 1)); + writeElem(1, M - 1, yWords, 8 * (M - 1)); + + return { initBuf: buf, initCounts: counts, initOffsets: offsets, M, points, bucket }; +} + +function buildLevelPlan( + counts: Uint32Array, + offsets: Uint32Array, + s: number, + padLIdx: number, + padRIdx: number, + discardIdx: number, +) { + const B = counts.length; + let totalPairs = 0; + let totalCarries = 0; + const newCounts = new Uint32Array(B); + for (let b = 0; b < B; b++) { + const n = counts[b]; + const p = Math.floor(n / 2); + const c = n & 1; + totalPairs += p; + totalCarries += c; + newCounts[b] = p + c; + } + const newOffsets = new Uint32Array(B + 1); + for (let b = 0; b < B; b++) newOffsets[b + 1] = newOffsets[b] + newCounts[b]; + + const numChunks = Math.max(1, Math.ceil(totalPairs / s)); + const chunkPlan = new Uint32Array(2 * s * numChunks); + const scatterPlan = new Uint32Array(s * numChunks); + const carryPlan = new Uint32Array(2 * Math.max(1, totalCarries)); + + for (let i = 0; i < numChunks * s; i++) { + chunkPlan[2 * i + 0] = padLIdx; + chunkPlan[2 * i + 1] = padRIdx; + scatterPlan[i] = discardIdx; + } + + let slot = 0; + let carryIdx = 0; + for (let b = 0; b < B; b++) { + const n = counts[b]; + const p = Math.floor(n / 2); + for (let j = 0; j < p; j++) { + chunkPlan[2 * slot + 0] = offsets[b] + 2 * j; + chunkPlan[2 * slot + 1] = offsets[b] + 2 * j + 1; + scatterPlan[slot] = newOffsets[b] + j; + slot++; + } + if (n & 1) { + carryPlan[2 * carryIdx + 0] = offsets[b] + n - 1; + carryPlan[2 * carryIdx + 1] = newOffsets[b] + p; + carryIdx++; + } + } + return { chunkPlan, scatterPlan, carryPlan, newCounts, newOffsets, numChunks, numCarries: totalCarries, totalPairs }; +} + +interface BucketCheck { + bucket: number; + count: number; + gpu_x: string; + gpu_y: string; + ref_x: string; + ref_y: string; + ok: boolean; +} + +interface OracleResult { + n: number; + buckets: number; + s: number; + wgi: number; + levels: number; + total_pair_adds: number; 
+ buckets_checked: number; + buckets_passed: number; + first_mismatches: BucketCheck[]; + all_passed: boolean; + gpu_wall_ms: number; +} + +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: Record | null; + results: OracleResult[]; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { state: 'boot', params: null, results: [], error: null, log: [] }; +(window as unknown as { __bench: BenchState }).__bench = benchState; +const resultsClient = makeResultsClient({ page: 'bench-msm-oracle' }); +(window as unknown as { __runId: string }).__runId = resultsClient.runId; + +async function postFinal(): Promise { + await resultsClient.postResults({ + state: benchState.state, + params: benchState.params, + results: benchState.results, + error: benchState.error, + log: benchState.log, + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); +} + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 
'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-msm-oracle] ${msg}`); +} + +async function compileOne(device: GPUDevice, code: string, key: string, layout: GPUBindGroupLayout): Promise { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + const errLines: string[] = []; + for (const m of info.messages) { + const line = `[shader ${key}] ${m.type}: ${m.message} (line ${m.lineNum}, col ${m.linePos})`; + if (m.type === 'error') { + console.error(line); + log('err', line); + errLines.push(line); + hasError = true; + } else { + console.warn(line); + } + } + if (hasError) throw new Error(`WGSL compile failed for ${key}: ${errLines.slice(0, 4).join(' | ')}`); + return device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); +} + +function ioLayout4(device: GPUDevice): GPUBindGroupLayout { + return device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); +} + +async function readbackU32(device: GPUDevice, buf: GPUBuffer, bytes: number): Promise { + const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }); + const enc = device.createCommandEncoder(); + enc.copyBufferToBuffer(buf, 0, staging, 0, bytes); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + await staging.mapAsync(GPUMapMode.READ); + const out = new 
Uint32Array(staging.getMappedRange().slice(0)); + staging.unmap(); + staging.destroy(); + return out; +} + +async function runOracle(device: GPUDevice, sm: ShaderManager, R: bigint, Rinv: bigint, p: bigint): Promise { + log('info', `=== N=${NPTS} B=${BUCKETS} S=${S} WGI=${WGI}`); + const rng = makeRng(SEED); + const { initBuf, initCounts, initOffsets, M, points, bucket } = buildL0WithRealPoints(NPTS, BUCKETS, R, p, rng); + + let cMin = NPTS, cMax = 0, cZero = 0; + for (let b = 0; b < BUCKETS; b++) { + if (initCounts[b] > cMax) cMax = initCounts[b]; + if (initCounts[b] < cMin) cMin = initCounts[b]; + if (initCounts[b] === 0) cZero++; + } + log('info', `built L0: M=${M} bucket counts min=${cMin} max=${cMax} zero=${cZero}/${BUCKETS}`); + + const padLIdx = M - 2; + const padRIdx = M - 1; + const discardIdx = M - 2; + + const bufA = makeSoABuf(device, M, true, true); + const bufB = makeSoABuf(device, M, true, true); + device.queue.writeBuffer(bufA, 0, initBuf); + device.queue.writeBuffer(bufB, 0, initBuf); + + const maxL0Chunks = Math.ceil(NPTS / 2 / S) + 1; + const chainBuf = makeSoABuf(device, 2 * S * maxL0Chunks, false, false); + const tempOutBuf = makeSoABuf(device, S * maxL0Chunks, false, true); + + const layoutMarshal = ioLayout4(device); + const layoutDisjoint = ioLayout4(device); + const layoutScatter = ioLayout4(device); + const layoutCarry = ioLayout4(device); + const marshalPipe = await compileOne(device, sm.gen_ba_marshal_pairs_bench_shader(WGI, S), `marshal-W${WGI}-S${S}`, layoutMarshal); + const disjointPipe = await compileOne(device, sm.gen_ba_pair_disjoint_tree_bench_shader(WGI, S), `disjoint-W${WGI}-S${S}`, layoutDisjoint); + const scatterPipe = await compileOne(device, sm.gen_ba_scatter_pairs_bench_shader(WGI, S), `scatter-W${WGI}-S${S}`, layoutScatter); + const carryPipe = await compileOne(device, sm.gen_ba_carry_copy_bench_shader(WGI), `carry-W${WGI}`, layoutCarry); + log('info', '4 pipelines compiled'); + + let counts = initCounts; + let offsets = 
initOffsets; + let finalOffsets: Uint32Array = initOffsets; + let curIn: GPUBuffer = bufA; + let curOut: GPUBuffer = bufB; + let totalPairAdds = 0; + let levelIdx = 0; + const dummy = device.createBuffer({ size: 16, usage: GPUBufferUsage.STORAGE }); + + interface PassSpec { pipeline: GPUComputePipeline; bind: GPUBindGroup; numWgs: number } + const allPasses: PassSpec[] = []; + const levelBufHolders: GPUBuffer[] = []; + + for (;;) { + let maxCount = 0; + for (let b = 0; b < counts.length; b++) if (counts[b] > maxCount) maxCount = counts[b]; + if (maxCount <= 1) { + finalOffsets = offsets; + break; + } + if (levelIdx > 24) throw new Error('exceeded safety level cap'); + + const plan = buildLevelPlan(counts, offsets, S, padLIdx, padRIdx, discardIdx); + totalPairAdds += plan.totalPairs; + const T = plan.numChunks; + const numWgs = Math.ceil(T / WGI); + log('info', `L${levelIdx}: T=${T} pairs=${plan.totalPairs} carries=${plan.numCarries} maxCount=${maxCount}`); + + const chunkPlanBuf = device.createBuffer({ size: plan.chunkPlan.byteLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(chunkPlanBuf, 0, plan.chunkPlan); + const scatterPlanBuf = device.createBuffer({ size: plan.scatterPlan.byteLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(scatterPlanBuf, 0, plan.scatterPlan); + const carryPlanBuf = device.createBuffer({ size: plan.carryPlan.byteLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(carryPlanBuf, 0, plan.carryPlan); + + const marshalParams = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(marshalParams, 0, new Uint32Array([T, M, 0, 0])); + const disjointParams = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(disjointParams, 0, new Uint32Array([2 * S * T, T, 1, 0])); + const scatterParams = 
device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(scatterParams, 0, new Uint32Array([T, M, 0, 0])); + const carryParams = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + if (plan.numCarries > 0) { + device.queue.writeBuffer(carryParams, 0, new Uint32Array([plan.numCarries, M, M, 0])); + } + + const marshalBind = device.createBindGroup({ + layout: layoutMarshal, + entries: [ + { binding: 0, resource: { buffer: chunkPlanBuf } }, + { binding: 1, resource: { buffer: curIn } }, + { binding: 2, resource: { buffer: chainBuf } }, + { binding: 3, resource: { buffer: marshalParams } }, + ], + }); + const disjointBind = device.createBindGroup({ + layout: layoutDisjoint, + entries: [ + { binding: 0, resource: { buffer: chainBuf } }, + { binding: 1, resource: { buffer: dummy } }, + { binding: 2, resource: { buffer: tempOutBuf } }, + { binding: 3, resource: { buffer: disjointParams } }, + ], + }); + const scatterBind = device.createBindGroup({ + layout: layoutScatter, + entries: [ + { binding: 0, resource: { buffer: scatterPlanBuf } }, + { binding: 1, resource: { buffer: tempOutBuf } }, + { binding: 2, resource: { buffer: curOut } }, + { binding: 3, resource: { buffer: scatterParams } }, + ], + }); + let carryBind: GPUBindGroup | null = null; + if (plan.numCarries > 0) { + carryBind = device.createBindGroup({ + layout: layoutCarry, + entries: [ + { binding: 0, resource: { buffer: carryPlanBuf } }, + { binding: 1, resource: { buffer: curIn } }, + { binding: 2, resource: { buffer: curOut } }, + { binding: 3, resource: { buffer: carryParams } }, + ], + }); + } + + allPasses.push({ pipeline: marshalPipe, bind: marshalBind, numWgs }); + allPasses.push({ pipeline: disjointPipe, bind: disjointBind, numWgs }); + allPasses.push({ pipeline: scatterPipe, bind: scatterBind, numWgs }); + if (plan.numCarries > 0 && carryBind) { + const carryWgs = Math.ceil(plan.numCarries / WGI); 
+ allPasses.push({ pipeline: carryPipe, bind: carryBind, numWgs: carryWgs }); + } + levelBufHolders.push(chunkPlanBuf, scatterPlanBuf, carryPlanBuf, marshalParams, disjointParams, scatterParams, carryParams); + + counts = plan.newCounts; + offsets = plan.newOffsets; + [curIn, curOut] = [curOut, curIn]; + levelIdx++; + } + + const enc = device.createCommandEncoder(); + for (const ps of allPasses) { + const pass = enc.beginComputePass(); + pass.setPipeline(ps.pipeline); + pass.setBindGroup(0, ps.bind); + pass.dispatchWorkgroups(ps.numWgs, 1, 1); + pass.end(); + } + const t0 = performance.now(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + const gpuWall = performance.now() - t0; + log('info', `batched ${allPasses.length} passes in one submit: ${gpuWall.toFixed(2)} ms`); + + const result = await readbackU32(device, curIn, 2 * PG * M * 4 * 4); + + bufA.destroy(); + bufB.destroy(); + chainBuf.destroy(); + tempOutBuf.destroy(); + dummy.destroy(); + for (const b of levelBufHolders) b.destroy(); + + const decodeAt = (slot: number): { x_mont: bigint; y_mont: bigint } => { + const xWords = new Uint32Array(8); + const yWords = new Uint32Array(8); + const planeBaseX = 0 * PG * M; + const planeBaseY = 1 * PG * M; + for (let v = 0; v < PG; v++) { + const baseX = (planeBaseX + PG * slot + v) * 4; + const baseY = (planeBaseY + PG * slot + v) * 4; + xWords[4 * v + 0] = result[baseX + 0]; + xWords[4 * v + 1] = result[baseX + 1]; + xWords[4 * v + 2] = result[baseX + 2]; + xWords[4 * v + 3] = result[baseX + 3]; + yWords[4 * v + 0] = result[baseY + 0]; + yWords[4 * v + 1] = result[baseY + 1]; + yWords[4 * v + 2] = result[baseY + 2]; + yWords[4 * v + 3] = result[baseY + 3]; + } + const x_mont = packedU32x8ToBigint(xWords, 0); + const y_mont = packedU32x8ToBigint(yWords, 0); + return { x_mont, y_mont }; + }; + + const G1 = bn254.G1.ProjectivePoint; + const refSumPerBucket = new Map(); + for (let b = 0; b < BUCKETS; b++) { + if (initCounts[b] === 0) 
continue; + let acc = G1.ZERO; + for (let i = 0; i < NPTS; i++) { + if (bucket[i] !== b) continue; + acc = acc.add(G1.fromAffine({ x: points[i].x, y: points[i].y })); + } + refSumPerBucket.set(b, acc.is0() ? null : acc.toAffine()); + } + + const checks: BucketCheck[] = []; + const mismatches: BucketCheck[] = []; + let passCount = 0; + for (let b = 0; b < BUCKETS; b++) { + if (initCounts[b] === 0) continue; + const slot = finalOffsets[b]; + const { x_mont, y_mont } = decodeAt(slot); + const gx = (x_mont * Rinv) % p; + const gy = (y_mont * Rinv) % p; + const ref = refSumPerBucket.get(b); + let ok = false; + if (ref === null) { + ok = gx === 0n && gy === 0n; + } else if (ref) { + ok = gx === ref.x && gy === ref.y; + } + const entry: BucketCheck = { + bucket: b, + count: initCounts[b], + gpu_x: gx.toString(16), + gpu_y: gy.toString(16), + ref_x: ref ? ref.x.toString(16) : 'INF', + ref_y: ref ? ref.y.toString(16) : 'INF', + ok, + }; + checks.push(entry); + if (ok) passCount++; + else if (mismatches.length < 8) mismatches.push(entry); + } + const allPassed = mismatches.length === 0 && passCount === checks.length; + + if (allPassed) { + log('ok', `oracle PASS — ${passCount}/${checks.length} buckets match noble reference`); + } else { + log('err', `oracle FAIL — ${checks.length - passCount}/${checks.length} buckets diverged (showing first ${mismatches.length})`); + for (const m of mismatches) { + log('err', ` bucket ${m.bucket} (count=${m.count})`); + log('err', ` gpu: x=${m.gpu_x} y=${m.gpu_y}`); + log('err', ` ref: x=${m.ref_x} y=${m.ref_y}`); + } + } + + return { + n: NPTS, + buckets: BUCKETS, + s: S, + wgi: WGI, + levels: levelIdx, + total_pair_adds: totalPairAdds, + buckets_checked: checks.length, + buckets_passed: passCount, + first_mismatches: mismatches, + all_passed: allPassed, + gpu_wall_ms: gpuWall, + }; +} + +function parseParams() { + const qp = new URLSearchParams(window.location.search); + if (qp.get('n')) NPTS = parseInt(qp.get('n')!, 10); + if 
(qp.get('buckets')) BUCKETS = parseInt(qp.get('buckets')!, 10); + if (qp.get('s')) S = parseInt(qp.get('s')!, 10); + if (qp.get('wgi')) WGI = parseInt(qp.get('wgi')!, 10); + if (qp.get('seed')) SEED = parseInt(qp.get('seed')!, 10); + return { n: NPTS, buckets: BUCKETS, s: S, wgi: WGI, seed: SEED }; +} + +async function main() { + try { + if (!('gpu' in navigator)) throw new Error('navigator.gpu missing'); + const params = parseParams(); + benchState.params = params; + log('info', `params: ${JSON.stringify(params)}`); + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + const p = BN254_BASE_FIELD; + const miscParams = compute_misc_params(p, 13); + const R = miscParams.r; + const Rinv = modInverse(R, p); + const sm = new ShaderManager(1, BUCKETS, BN254_CURVE_CONFIG, false); + const r = await runOracle(device, sm, R, Rinv, p); + benchState.results.push(r); + resultsClient.postProgress({ + kind: 'oracle_done', + all_passed: r.all_passed, + buckets_passed: r.buckets_passed, + buckets_checked: r.buckets_checked, + gpu_wall_ms: r.gpu_wall_ms, + }); + benchState.state = 'done'; + log('ok', 'done'); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main() + .catch(e => { + const msg = e instanceof Error ? e.message : String(e); + log('err', `unhandled: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + }) + .finally(() => { + postFinal().catch(() => {}); + }); diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.html b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.html new file mode 100644 index 000000000000..bd33e10887af --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.html @@ -0,0 +1,22 @@ + + + + + Bin-packed pair-tree MSM bucket-accumulate v2 (WebGPU) + + + +

Bin-packed pair-tree MSM bucket-accumulate v2 (WebGPU)

+

Query params: ?reps=R&n=N&buckets=B&s=S&wgi=W

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.ts new file mode 100644 index 000000000000..b0f97f6081d8 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v2.ts @@ -0,0 +1,566 @@ +/// +// bench-msm-tree-v2 — bin-packed pair-tree MSM bucket-accumulate with +// carry-forward. Eliminates the slow per-bucket tail kernel by packing +// pairs from any combination of buckets into the same chunk. +// +// For each (chunk t, slot k), the disjoint kernel sums (P_{2k}, P_{2k+1}). +// Both operands of each pair come from the SAME bucket; different +// (chunk, slot) entries can come from DIFFERENT buckets. The planner +// guarantees the within-pair bucket invariant. +// +// Per level transition: +// 1. host: per-bucket pair-count + carry, bin-pack into chunks of S. +// 2. marshal-pairs (GPU): gather operands per chunk_plan into chain_buf. +// 3. tree-disjoint (GPU, final=1): chain_buf -> simple strided output. +// 4. scatter-pairs (GPU): outputs -> active_sums_new at per-bucket positions. +// 5. carry-copy (GPU): odd-count carries -> active_sums_new (if any). +// 6. swap active_sums buffers, update counts/offsets. +// +// Terminate when max bucket count == 1. 
+ +import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js'; +import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js'; +import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js'; +import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js'; +import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js'; +import { makeResultsClient } from './results_post.js'; + +const PG = 2; +const DEFAULT_N = 1 << 17; +const DEFAULT_BUCKETS = 1 << 12; +const DEFAULT_S = 16; +const DEFAULT_WGI = 64; + +let NPTS = DEFAULT_N; +let BUCKETS = DEFAULT_BUCKETS; +let S = DEFAULT_S; +let WGI = DEFAULT_WGI; + +function makeRng(seed: number): () => number { + let state = (seed >>> 0) || 1; + return () => { + state = (Math.imul(state, 1664525) + 1013904223) >>> 0; + return state; + }; +} + +function randomBelow(p: bigint, rng: () => number): bigint { + const bitlen = p.toString(2).length; + const byteLen = Math.ceil(bitlen / 8); + for (;;) { + let v = 0n; + for (let i = 0; i < byteLen; i++) v = (v << 8n) | BigInt(rng() & 0xff); + v &= (1n << BigInt(bitlen)) - 1n; + if (v > 0n && v < p) return v; + } +} + +function bigintToPackedU32x8(v: bigint): Uint32Array { + const w = new Uint32Array(8); + let x = v; + for (let i = 0; i < 8; i++) { + w[i] = Number(x & 0xffffffffn); + x >>= 32n; + } + return w; +} + +function makeSoABuf(device: GPUDevice, M: number, copyDst: boolean, copySrc: boolean): GPUBuffer { + const bytes = 2 * PG * M * 4 * 4; + let usage = GPUBufferUsage.STORAGE; + if (copyDst) usage |= GPUBufferUsage.COPY_DST; + if (copySrc) usage |= GPUBufferUsage.COPY_SRC; + return device.createBuffer({ size: bytes, usage }); +} + +// Build initial active_sums (Level 0). Points are assigned to random +// buckets and laid out bucket-major (bucket b's points at active_sums +// indices offsets[b] .. offsets[b]+counts[b]-1). 
The last 2 slots [M-2, +// M-1] hold a "pad pair" with distinct x — used to fill chunk-tail +// slots without divide-by-zero. +function buildL0ActiveSums(N: number, B: number, R: bigint, p: bigint, rng: () => number) { + const M = N + 2; + const buf = new Uint32Array(2 * PG * M * 4); + // Generate N + 2 random points. + const xWords = new Uint32Array(8 * M); + const yWords = new Uint32Array(8 * M); + for (let i = 0; i < M; i++) { + const x = (randomBelow(p, rng) * R) % p; + const y = (randomBelow(p, rng) * R) % p; + xWords.set(bigintToPackedU32x8(x), 8 * i); + yWords.set(bigintToPackedU32x8(y), 8 * i); + } + // Ensure pad pair x's differ. + if (xWords[8 * (M - 2)] === xWords[8 * (M - 1)]) { + xWords[8 * (M - 1)] ^= 1; + } + // Bucket assignment via composed hi/lo (unsigned). + const bucket = new Uint32Array(N); + const counts = new Uint32Array(B); + for (let i = 0; i < N; i++) { + const hi = (rng() >>> 16) & 0xffff; + const lo = (rng() >>> 16) & 0xffff; + const v = hi * 0x10000 + lo; + const b = v % B; + bucket[i] = b; + counts[b]++; + } + const offsets = new Uint32Array(B + 1); + for (let b = 0; b < B; b++) offsets[b + 1] = offsets[b] + counts[b]; + const cursor = new Uint32Array(B); + const writeElem = (planeIdx: number, dstIdx: number, words: Uint32Array, srcOff: number) => { + for (let v = 0; v < PG; v++) { + const base = ((planeIdx * PG + v) * M + dstIdx) * 4; + buf[base + 0] = words[srcOff + 4 * v + 0]; + buf[base + 1] = words[srcOff + 4 * v + 1]; + buf[base + 2] = words[srcOff + 4 * v + 2]; + buf[base + 3] = words[srcOff + 4 * v + 3]; + } + }; + for (let i = 0; i < N; i++) { + const b = bucket[i]; + const dst = offsets[b] + cursor[b]++; + writeElem(0, dst, xWords, 8 * i); + writeElem(1, dst, yWords, 8 * i); + } + // Pad pair at indices M-2, M-1. 
+ writeElem(0, M - 2, xWords, 8 * (M - 2)); + writeElem(1, M - 2, yWords, 8 * (M - 2)); + writeElem(0, M - 1, xWords, 8 * (M - 1)); + writeElem(1, M - 1, yWords, 8 * (M - 1)); + return { initBuf: buf, initCounts: counts, initOffsets: offsets, M }; +} + +// Bin-pack the per-bucket pairs into chunks of S. Returns the per-chunk +// operand-index plan and the per-output destination plan, plus the +// carry plan and next-level (counts, offsets). +function buildLevelPlan( + counts: Uint32Array, + offsets: Uint32Array, + s: number, + padLIdx: number, + padRIdx: number, + discardIdx: number, +) { + const B = counts.length; + let totalPairs = 0; + let totalCarries = 0; + const newCounts = new Uint32Array(B); + for (let b = 0; b < B; b++) { + const n = counts[b]; + const p = Math.floor(n / 2); + const c = n & 1; + totalPairs += p; + totalCarries += c; + newCounts[b] = p + c; + } + const newOffsets = new Uint32Array(B + 1); + for (let b = 0; b < B; b++) newOffsets[b + 1] = newOffsets[b] + newCounts[b]; + + const numChunks = Math.max(1, Math.ceil(totalPairs / s)); + const chunkPlan = new Uint32Array(2 * s * numChunks); + const scatterPlan = new Uint32Array(s * numChunks); + const carryPlan = new Uint32Array(2 * Math.max(1, totalCarries)); + + for (let i = 0; i < numChunks * s; i++) { + chunkPlan[2 * i + 0] = padLIdx; + chunkPlan[2 * i + 1] = padRIdx; + scatterPlan[i] = discardIdx; + } + + let slot = 0; + let carryIdx = 0; + for (let b = 0; b < B; b++) { + const n = counts[b]; + const p = Math.floor(n / 2); + for (let j = 0; j < p; j++) { + chunkPlan[2 * slot + 0] = offsets[b] + 2 * j; + chunkPlan[2 * slot + 1] = offsets[b] + 2 * j + 1; + scatterPlan[slot] = newOffsets[b] + j; + slot++; + } + if (n & 1) { + carryPlan[2 * carryIdx + 0] = offsets[b] + n - 1; + carryPlan[2 * carryIdx + 1] = newOffsets[b] + p; + carryIdx++; + } + } + return { chunkPlan, scatterPlan, carryPlan, newCounts, newOffsets, numChunks, numCarries: totalCarries, totalPairs }; +} + +interface LevelTiming { + 
T: number; + pairs: number; + carries: number; + marshal_ms: number; + disjoint_ms: number; + scatter_ms: number; + carry_ms: number; +} + +interface RunResult { + s: number; + wgi: number; + pairs: number; + buckets: number; + levels: number; + total_pair_adds: number; + total_wall_ms: number; + level_timings: LevelTiming[]; + ns_per_inpt: number; + sanity_ok: boolean; +} + +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: { reps: number; n: number; buckets: number; s: number; wgi: number } | null; + results: RunResult[]; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { state: 'boot', params: null, results: [], error: null, log: [] }; +(window as unknown as { __bench: BenchState }).__bench = benchState; + +const resultsClient = makeResultsClient({ page: 'bench-msm-tree-v2' }); +(window as unknown as { __runId: string }).__runId = resultsClient.runId; + +async function postFinal(): Promise { + await resultsClient.postResults({ + state: benchState.state, params: benchState.params, results: benchState.results, + error: benchState.error, log: benchState.log, + userAgent: navigator.userAgent, hardwareConcurrency: navigator.hardwareConcurrency, + }); +} + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 
'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-msm-tree-v2] ${msg}`); +} + +async function compileOne(device: GPUDevice, code: string, key: string, layout: GPUBindGroupLayout): Promise { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + const errLines: string[] = []; + for (const m of info.messages) { + const line = `[shader ${key}] ${m.type}: ${m.message} (line ${m.lineNum}, col ${m.linePos})`; + if (m.type === 'error') { console.error(line); log('err', line); errLines.push(line); hasError = true; } + else { console.warn(line); } + } + if (hasError) throw new Error(`WGSL compile failed for ${key}: ${errLines.slice(0, 4).join(' | ')}`); + return device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); +} + +function ioLayout4(device: GPUDevice): GPUBindGroupLayout { + return device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); +} + +async function readNonZero(device: GPUDevice, buf: GPUBuffer, u32Count: number): Promise { + const bytes = u32Count * 4; + const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }); + const enc = device.createCommandEncoder(); + enc.copyBufferToBuffer(buf, 0, staging, 0, bytes); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + await staging.mapAsync(GPUMapMode.READ); + const 
u32 = new Uint32Array(staging.getMappedRange().slice(0)); + staging.unmap(); + staging.destroy(); + for (let i = 0; i < u32.length; i++) if (u32[i] !== 0) return true; + return false; +} + +interface PassSpec { pipeline: GPUComputePipeline; bind: GPUBindGroup; numWgs: number } + +// Encode multiple passes into one command encoder, submit once, await +// once. Returns the total wall time. Submit-overhead is paid once +// across all passes — the right way to measure a fused pipeline. +async function timeBatched(device: GPUDevice, passes: PassSpec[]): Promise { + const enc = device.createCommandEncoder(); + for (const p of passes) { + const pass = enc.beginComputePass(); + pass.setPipeline(p.pipeline); + pass.setBindGroup(0, p.bind); + pass.dispatchWorkgroups(p.numWgs, 1, 1); + pass.end(); + } + const t0 = performance.now(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + return performance.now() - t0; +} + +async function runPipeline(device: GPUDevice, sm: ShaderManager, reps: number, R: bigint, p: bigint): Promise { + log('info', `=== N=${NPTS} B=${BUCKETS} S=${S} WGI=${WGI}`); + + const rng = makeRng(0x9111); + const { initBuf, initCounts, initOffsets, M } = buildL0ActiveSums(NPTS, BUCKETS, R, p, rng); + log('info', `built L0 active_sums: M=${M}`); + + // Histogram peek + let maxC0 = 0, minC0 = NPTS, c0 = 0, smallC = 0; + for (let b = 0; b < initCounts.length; b++) { + if (initCounts[b] > maxC0) maxC0 = initCounts[b]; + if (initCounts[b] < minC0) minC0 = initCounts[b]; + if (initCounts[b] === 0) c0++; + if (initCounts[b] < 32) smallC++; + } + log('info', `bucket counts: min=${minC0} max=${maxC0} zero=${c0} small(<32)=${smallC}/${initCounts.length}`); + + const padLIdx = M - 2; + const padRIdx = M - 1; + // Use a fixed discard slot: M-2 (same as padLIdx). 
The discarded output + // overwrites the pad pair on each level — fine because the pad pair is + // re-set by buildL0 once and never relied on for correctness; in + // subsequent levels the planner's pad selection still finds a valid + // distinct-x pair as long as M-2 and M-1 hold distinct-x data at start. + // Per level, we re-seed the pad slots by copying first two slots of + // active_sums_new... actually simpler: use NEW pad slots per level. We + // need slots that have distinct x in active_sums_new. For pure safety, + // we'll have the planner discard scatters always go to (M-2) and pad + // chunks always read from active_sums_OLD's (padLIdx, padRIdx) — which + // is still ping-pong-stable if we maintain pad slots in both buffers. + const discardIdx = M - 2; + + // Two ping-pong active_sums buffers, sized M. + const bufA = makeSoABuf(device, M, true, true); + const bufB = makeSoABuf(device, M, true, true); + device.queue.writeBuffer(bufA, 0, initBuf); + // Mirror the pad pair into bufB so it's available when we ping-pong. + // Read M-2 and M-1 from initBuf and write them into bufB at the same slots. + const padPairBytes = new Uint32Array(2 * PG * 2 * 4); + for (let pl = 0; pl < 2; pl++) { + for (let v = 0; v < PG; v++) { + const baseSrc = ((pl * PG + v) * M + (M - 2)) * 4; + const baseDst = (pl * PG + v) * 2 * 4 + 0; + padPairBytes[baseDst + 0] = initBuf[baseSrc + 0]; + padPairBytes[baseDst + 1] = initBuf[baseSrc + 1]; + padPairBytes[baseDst + 2] = initBuf[baseSrc + 2]; + padPairBytes[baseDst + 3] = initBuf[baseSrc + 3]; + const baseSrc2 = ((pl * PG + v) * M + (M - 1)) * 4; + padPairBytes[baseDst + 4] = initBuf[baseSrc2 + 0]; + padPairBytes[baseDst + 5] = initBuf[baseSrc2 + 1]; + padPairBytes[baseDst + 6] = initBuf[baseSrc2 + 2]; + padPairBytes[baseDst + 7] = initBuf[baseSrc2 + 3]; + } + } + // padPairBytes layout: same SoA but with M=2. Write into both bufA pad + // region and bufB pad region. 
bufA already has them via initBuf; for + // bufB write a sparse pad pair via a small upload at the pad offset. + // Simpler: write the entire initBuf into bufB too (initial state matches). + device.queue.writeBuffer(bufB, 0, initBuf); + + // Scratch buffers. + const maxL0Chunks = Math.ceil(NPTS / 2 / S) + 1; + const chainBuf = makeSoABuf(device, 2 * S * maxL0Chunks, false, false); + const tempOutBuf = makeSoABuf(device, S * maxL0Chunks, false, true); + + // Compile all 4 pipelines. + const layoutMarshal = ioLayout4(device); + const layoutDisjoint = ioLayout4(device); + const layoutScatter = ioLayout4(device); + const layoutCarry = ioLayout4(device); + const marshalPipe = await compileOne(device, sm.gen_ba_marshal_pairs_bench_shader(WGI, S), `marshal-pairs-W${WGI}-S${S}`, layoutMarshal); + const disjointPipe = await compileOne(device, sm.gen_ba_pair_disjoint_tree_bench_shader(WGI, S), `disjoint-W${WGI}-S${S}`, layoutDisjoint); + const scatterPipe = await compileOne(device, sm.gen_ba_scatter_pairs_bench_shader(WGI, S), `scatter-pairs-W${WGI}-S${S}`, layoutScatter); + const carryPipe = await compileOne(device, sm.gen_ba_carry_copy_bench_shader(WGI), `carry-W${WGI}`, layoutCarry); + log('info', '4 pipelines compiled'); + + // Iterate. + let counts = initCounts; + let offsets = initOffsets; + let curIn: GPUBuffer = bufA; + let curOut: GPUBuffer = bufB; + let totalPairAdds = 0; + let levelIdx = 0; + const levelTimings: LevelTiming[] = []; + const dummy = device.createBuffer({ size: 16, usage: GPUBufferUsage.STORAGE }); + // Pre-pass: enqueue plan uploads and encode all-level passes. The + // device queue processes the writeBuffer calls in order before the + // single submit; the GPU inserts barriers between dependent storage + // reads/writes within the encoder. One submit + one await amortises + // submit overhead across the entire bucket-accumulate. 
+ const allPasses: PassSpec[] = []; + + const startTime = performance.now(); + + for (;;) { + let maxCount = 0; + for (let b = 0; b < counts.length; b++) if (counts[b] > maxCount) maxCount = counts[b]; + if (maxCount <= 1) break; + if (levelIdx > 24) throw new Error('exceeded safety level cap'); + + const plan = buildLevelPlan(counts, offsets, S, padLIdx, padRIdx, discardIdx); + totalPairAdds += plan.totalPairs; + const T = plan.numChunks; + const numWgs = Math.ceil(T / WGI); + log('info', `L${levelIdx}: T=${T} pairs=${plan.totalPairs} carries=${plan.numCarries} maxCount=${maxCount}`); + + const chunkPlanBuf = device.createBuffer({ size: plan.chunkPlan.byteLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(chunkPlanBuf, 0, plan.chunkPlan); + const scatterPlanBuf = device.createBuffer({ size: plan.scatterPlan.byteLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(scatterPlanBuf, 0, plan.scatterPlan); + const carryPlanBuf = device.createBuffer({ size: plan.carryPlan.byteLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(carryPlanBuf, 0, plan.carryPlan); + + const marshalParams = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(marshalParams, 0, new Uint32Array([T, M, 0, 0])); + const disjointParams = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(disjointParams, 0, new Uint32Array([2 * S * T, T, 1, 0])); // final_flag=1 + const scatterParams = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(scatterParams, 0, new Uint32Array([T, M, 0, 0])); + const carryParams = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + if (plan.numCarries > 0) { + device.queue.writeBuffer(carryParams, 0, new 
Uint32Array([plan.numCarries, M, M, 0])); + } + + const marshalBind = device.createBindGroup({ + layout: layoutMarshal, entries: [ + { binding: 0, resource: { buffer: chunkPlanBuf } }, + { binding: 1, resource: { buffer: curIn } }, + { binding: 2, resource: { buffer: chainBuf } }, + { binding: 3, resource: { buffer: marshalParams } }, + ], + }); + const disjointBind = device.createBindGroup({ + layout: layoutDisjoint, entries: [ + { binding: 0, resource: { buffer: chainBuf } }, + { binding: 1, resource: { buffer: dummy } }, + { binding: 2, resource: { buffer: tempOutBuf } }, + { binding: 3, resource: { buffer: disjointParams } }, + ], + }); + const scatterBind = device.createBindGroup({ + layout: layoutScatter, entries: [ + { binding: 0, resource: { buffer: scatterPlanBuf } }, + { binding: 1, resource: { buffer: tempOutBuf } }, + { binding: 2, resource: { buffer: curOut } }, + { binding: 3, resource: { buffer: scatterParams } }, + ], + }); + let carryBind: GPUBindGroup | null = null; + if (plan.numCarries > 0) { + carryBind = device.createBindGroup({ + layout: layoutCarry, entries: [ + { binding: 0, resource: { buffer: carryPlanBuf } }, + { binding: 1, resource: { buffer: curIn } }, + { binding: 2, resource: { buffer: curOut } }, + { binding: 3, resource: { buffer: carryParams } }, + ], + }); + } + + // Stash level's passes into the outer pass list to be timed as a + // single batched submit across ALL levels. 
+ allPasses.push({ pipeline: marshalPipe, bind: marshalBind, numWgs }); + allPasses.push({ pipeline: disjointPipe, bind: disjointBind, numWgs }); + allPasses.push({ pipeline: scatterPipe, bind: scatterBind, numWgs }); + if (plan.numCarries > 0 && carryBind) { + const carryWgs = Math.ceil(plan.numCarries / WGI); + allPasses.push({ pipeline: carryPipe, bind: carryBind, numWgs: carryWgs }); + } + levelTimings.push({ + T, pairs: plan.totalPairs, carries: plan.numCarries, + marshal_ms: 0, disjoint_ms: 0, scatter_ms: 0, carry_ms: 0, + }); + log('info', ` L${levelIdx} encoded (T=${T}, pairs=${plan.totalPairs})`); + + // Cleanup level-local buffers. + chunkPlanBuf.destroy(); + scatterPlanBuf.destroy(); + carryPlanBuf.destroy(); + marshalParams.destroy(); + disjointParams.destroy(); + scatterParams.destroy(); + carryParams.destroy(); + + counts = plan.newCounts; + offsets = plan.newOffsets; + [curIn, curOut] = [curOut, curIn]; + levelIdx++; + } + + // Single batched submit for ALL levels. + const wallSubmit = performance.now(); + const totalWall = await timeBatched(device, allPasses); + log('info', `batched ${allPasses.length} passes in one submit: ${totalWall.toFixed(2)}ms`); + + const wall = performance.now() - startTime; // includes plan-build + upload + GPU time + const sanity = await readNonZero(device, curIn, 8); + + bufA.destroy(); + bufB.destroy(); + chainBuf.destroy(); + tempOutBuf.destroy(); + dummy.destroy(); + + const nsPerInpt = (totalWall * 1e6) / NPTS; + log( + sanity ? 'ok' : 'err', + `pipeline: ${levelIdx} levels, ${totalPairAdds} pair-adds, single-submit GPU wall=${totalWall.toFixed(2)}ms, total incl plan-upload=${wall.toFixed(2)}ms, ns/in-pt=${nsPerInpt.toFixed(2)}, sanity=${sanity ? 
'OK' : 'FAIL'}`, + ); + + return { + s: S, wgi: WGI, pairs: NPTS, buckets: BUCKETS, levels: levelIdx, + total_pair_adds: totalPairAdds, total_wall_ms: totalWall, level_timings: levelTimings, + ns_per_inpt: nsPerInpt, sanity_ok: sanity, + }; +} + +function parseParams() { + const qp = new URLSearchParams(window.location.search); + const reps = parseInt(qp.get('reps') ?? '3', 10); + if (!Number.isFinite(reps) || reps <= 0 || reps > 50) throw new Error(`?reps must be in (0, 50]`); + if (qp.get('n')) NPTS = parseInt(qp.get('n')!, 10); + if (qp.get('buckets')) BUCKETS = parseInt(qp.get('buckets')!, 10); + if (qp.get('s')) S = parseInt(qp.get('s')!, 10); + if (qp.get('wgi')) WGI = parseInt(qp.get('wgi')!, 10); + return { reps, n: NPTS, buckets: BUCKETS, s: S, wgi: WGI }; +} + +async function main() { + try { + if (!('gpu' in navigator)) throw new Error('navigator.gpu missing'); + const params = parseParams(); + benchState.params = params; + log('info', `params: reps=${params.reps} n=${params.n} buckets=${params.buckets} s=${params.s} wgi=${params.wgi}`); + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + const p = BN254_BASE_FIELD; + const miscParams = compute_misc_params(p, 13); + const R = miscParams.r; + const sm = new ShaderManager(4, NPTS, BN254_CURVE_CONFIG, false); + const r = await runPipeline(device, sm, params.reps, R, p); + benchState.results.push(r); + resultsClient.postProgress({ kind: 'pipeline_done', ns_per_inpt: r.ns_per_inpt, sanity_ok: r.sanity_ok }); + benchState.state = 'done'; + log('ok', 'done'); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main() + .catch(e => { const msg = e instanceof Error ? 
e.message : String(e); log('err', `unhandled: ${msg}`); benchState.state = 'error'; benchState.error = msg; }) + .finally(() => { postFinal().catch(() => {}); }); diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.html b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.html new file mode 100644 index 000000000000..8eaa024c2f13 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.html @@ -0,0 +1,22 @@ + + + + + v3 — GPU planner + fused super-kernel MSM bucket-accumulate + + + +

v3 — GPU planner + fused super-kernel MSM bucket-accumulate

+

Query params: ?n=N&buckets=B&s=S&wgi=W&levels=L

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.ts new file mode 100644 index 000000000000..4b420b00ee7d --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree-v3.ts @@ -0,0 +1,538 @@ +/// +// bench-msm-tree-v3 — GPU-side planner + fused super-kernel. +// +// Per level (all GPU, encoded into one command list): +// 1. Reset totals atomic counter to 0. +// 2. Pre-pad chunk_plan + scatter_plan + carry_plan to safe values. +// 3. Planner kernel: 1 thread per bucket -> writes chunk_plan, +// scatter_plan, carry_plan, new_counts, new_offsets via atomic +// offset reservation. +// 4. Fused super-kernel: marshal + disjoint + scatter in one pass. +// Reads chunk_plan + scatter_plan + active_sums_old; writes +// active_sums_new. +// 5. Carry kernel: copies odd-count carries from active_sums_old to +// active_sums_new. +// 6. Swap active_sums buffers, swap counts/offsets buffers. +// +// Over-dispatch L_MAX levels (default 8 for Poisson(λ=32) where max +// bucket count is ~50-60, requiring ceil(log2(60)) = 6 levels). Extra +// levels with all-count-1 input do no useful work (planner produces +// zero pairs; the host clamps the fused dispatch to at least one +// chunk, so a single all-pad chunk is still processed; carry kernel +// just copies the single-element-per-bucket data forward). +// +// Single submit across all 3*L_MAX kernel dispatches. Zero host-GPU +// round-trips between scalar-decompose and final readback. 
import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js';
import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js';
import { makeResultsClient } from './results_post.js';

// Number of 4-word (vec4-sized) groups per packed field element:
// 8 packed u32 limbs = PG * 4.
const PG = 2;
const DEFAULT_N = 1 << 17; // input points
const DEFAULT_BUCKETS = 1 << 12;
const DEFAULT_S = 16; // pair-sum slots per chunk
const DEFAULT_WGI = 64; // workgroup size
const DEFAULT_LEVELS = 8; // tree levels dispatched

// Runtime-tunable copies of the defaults. The page advertises
// ?n / ?buckets / ?s / ?wgi / ?levels query parameters for these; the
// parser itself is not part of this section.
let NPTS = DEFAULT_N;
let BUCKETS = DEFAULT_BUCKETS;
let S = DEFAULT_S;
let WGI = DEFAULT_WGI;
let LEVELS = DEFAULT_LEVELS;

/**
 * Deterministic 32-bit linear congruential generator
 * (state' = state * 1664525 + 1013904223 mod 2^32).
 *
 * @param seed Coerced to uint32; a zero seed is remapped to 1.
 * @returns A function yielding the next uint32 state on each call.
 */
function makeRng(seed: number): () => number {
  let state = (seed >>> 0) || 1;
  return () => {
    state = (Math.imul(state, 1664525) + 1013904223) >>> 0;
    return state;
  };
}

/**
 * Rejection-samples a value in the open interval (0, p).
 *
 * Bytes are taken from the HIGH 8 bits of each `rng()` output. With a
 * power-of-two-modulus LCG such as `makeRng`, the low 8 bits depend only
 * on the previous low 8 bits and therefore cycle with period <= 256; a
 * 32-byte draw built from them can take at most 8 distinct values, which
 * makes equal x-coordinates (and a zero denominator in the affine add)
 * routine instead of negligible. The high bits have the full period.
 *
 * @param p Exclusive upper bound; must be greater than 1.
 * @param rng Source of uint32 values.
 * @returns A bigint v with 0 < v < p.
 * @throws RangeError if `p <= 1` (the interval would be empty and the
 *   rejection loop could never terminate).
 */
function randomBelow(p: bigint, rng: () => number): bigint {
  if (p <= 1n) throw new RangeError(`randomBelow: modulus must be > 1, got ${p}`);
  const bitlen = p.toString(2).length;
  const byteLen = Math.ceil(bitlen / 8);
  // Loop-invariant: trims each candidate to the bit length of p so the
  // acceptance rate stays above 50%.
  const mask = (1n << BigInt(bitlen)) - 1n;
  for (;;) {
    let v = 0n;
    for (let i = 0; i < byteLen; i++) v = (v << 8n) | BigInt(rng() >>> 24);
    v &= mask;
    if (v > 0n && v < p) return v;
  }
}

/**
 * Splits a non-negative bigint into eight 32-bit limbs, least-significant
 * limb first. Bits above 2^256 are discarded.
 *
 * @param v Value to pack.
 * @returns A fresh `Uint32Array` of length 8.
 */
function bigintToPackedU32x8(v: bigint): Uint32Array {
  const w = new Uint32Array(8);
  let x = v;
  for (let i = 0; i < 8; i++) {
    w[i] = Number(x & 0xffffffffn);
    x >>= 32n;
  }
  return w;
}

// Build initial active_sums and per-bucket counts/offsets (level 0).
+function buildL0(N: number, B: number, R: bigint, p: bigint, rng: () => number) { + const M = N + 2; + const buf = new Uint32Array(2 * PG * M * 4); + const xWords = new Uint32Array(8 * M); + const yWords = new Uint32Array(8 * M); + for (let i = 0; i < M; i++) { + const x = (randomBelow(p, rng) * R) % p; + const y = (randomBelow(p, rng) * R) % p; + xWords.set(bigintToPackedU32x8(x), 8 * i); + yWords.set(bigintToPackedU32x8(y), 8 * i); + } + if (xWords[8 * (M - 2)] === xWords[8 * (M - 1)]) xWords[8 * (M - 1)] ^= 1; + const bucket = new Uint32Array(N); + const counts = new Uint32Array(B); + for (let i = 0; i < N; i++) { + const hi = (rng() >>> 16) & 0xffff; + const lo = (rng() >>> 16) & 0xffff; + const v = hi * 0x10000 + lo; + const b = v % B; + bucket[i] = b; + counts[b]++; + } + const offsets = new Uint32Array(B + 1); + for (let b = 0; b < B; b++) offsets[b + 1] = offsets[b] + counts[b]; + const cursor = new Uint32Array(B); + const writeElem = (planeIdx: number, dstIdx: number, words: Uint32Array, srcOff: number) => { + for (let v = 0; v < PG; v++) { + const base = ((planeIdx * PG + v) * M + dstIdx) * 4; + buf[base + 0] = words[srcOff + 4 * v + 0]; + buf[base + 1] = words[srcOff + 4 * v + 1]; + buf[base + 2] = words[srcOff + 4 * v + 2]; + buf[base + 3] = words[srcOff + 4 * v + 3]; + } + }; + for (let i = 0; i < N; i++) { + const b = bucket[i]; + const dst = offsets[b] + cursor[b]++; + writeElem(0, dst, xWords, 8 * i); + writeElem(1, dst, yWords, 8 * i); + } + writeElem(0, M - 2, xWords, 8 * (M - 2)); + writeElem(1, M - 2, yWords, 8 * (M - 2)); + writeElem(0, M - 1, xWords, 8 * (M - 1)); + writeElem(1, M - 1, yWords, 8 * (M - 1)); + return { initBuf: buf, initCounts: counts, initOffsets: offsets, M }; +} + +function makeSoABuf(device: GPUDevice, M: number, copyDst: boolean, copySrc: boolean): GPUBuffer { + const bytes = 2 * PG * M * 4 * 4; + let usage = GPUBufferUsage.STORAGE; + if (copyDst) usage |= GPUBufferUsage.COPY_DST; + if (copySrc) usage |= 
GPUBufferUsage.COPY_SRC; + return device.createBuffer({ size: bytes, usage }); +} + +interface RunResult { + s: number; + wgi: number; + pairs: number; + buckets: number; + levels_run: number; + gpu_wall_ms: number; + ns_per_inpt: number; + sanity_ok: boolean; +} + +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: { n: number; buckets: number; s: number; wgi: number; levels: number } | null; + results: RunResult[]; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { state: 'boot', params: null, results: [], error: null, log: [] }; +(window as unknown as { __bench: BenchState }).__bench = benchState; + +const resultsClient = makeResultsClient({ page: 'bench-msm-tree-v3' }); +(window as unknown as { __runId: string }).__runId = resultsClient.runId; + +async function postFinal(): Promise { + await resultsClient.postResults({ + state: benchState.state, params: benchState.params, results: benchState.results, + error: benchState.error, log: benchState.log, + userAgent: navigator.userAgent, hardwareConcurrency: navigator.hardwareConcurrency, + }); +} + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 
'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-msm-tree-v3] ${msg}`); +} + +async function compileOne(device: GPUDevice, code: string, key: string, layout: GPUBindGroupLayout): Promise { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + const errLines: string[] = []; + for (const m of info.messages) { + const line = `[shader ${key}] ${m.type}: ${m.message} (line ${m.lineNum}, col ${m.linePos})`; + if (m.type === 'error') { console.error(line); log('err', line); errLines.push(line); hasError = true; } + else { console.warn(line); } + } + if (hasError) throw new Error(`WGSL compile failed for ${key}: ${errLines.slice(0, 4).join(' | ')}`); + return device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); +} + +async function readNonZero(device: GPUDevice, buf: GPUBuffer, u32Count: number): Promise { + const bytes = u32Count * 4; + const staging = device.createBuffer({ size: bytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }); + const enc = device.createCommandEncoder(); + enc.copyBufferToBuffer(buf, 0, staging, 0, bytes); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + await staging.mapAsync(GPUMapMode.READ); + const u32 = new Uint32Array(staging.getMappedRange().slice(0)); + staging.unmap(); + staging.destroy(); + for (let i = 0; i < u32.length; i++) if (u32[i] !== 0) return true; + return false; +} + +async function runPipeline(device: GPUDevice, sm: ShaderManager, R: bigint, p: bigint): Promise { + log('info', `=== N=${NPTS} B=${BUCKETS} S=${S} WGI=${WGI} LEVELS=${LEVELS}`); + + const rng = makeRng(0xc711); + const { initBuf, initCounts, initOffsets, M } = buildL0(NPTS, BUCKETS, R, p, 
rng); + log('info', `L0 active_sums: M=${M}, B=${BUCKETS}`); + + // Histogram peek + let maxC = 0, minC = NPTS, smallC = 0; + for (let b = 0; b < initCounts.length; b++) { + if (initCounts[b] > maxC) maxC = initCounts[b]; + if (initCounts[b] < minC) minC = initCounts[b]; + if (initCounts[b] < 32) smallC++; + } + log('info', `bucket counts: min=${minC} max=${maxC} small(<32)=${smallC}/${BUCKETS}`); + + // Plan-buffer sizing — must accommodate L0 max chunks. + const MAX_CHUNKS = Math.ceil(NPTS / 2 / S) + 16; + const MAX_PAIR_SLOTS = MAX_CHUNKS * S; + const MAX_CARRIES = BUCKETS; + + // Host simulates the bin-packing iteration to compute the right + // dispatch size per level. The GPU planner does the same work to + // fill plan buffers; this host loop only computes sizes (T_chunks, + // T_carries) so the host can dispatch the fused + carry kernels at + // the correct size per level, avoiding pad-chunk waste. + // + // Without this, the fused kernel runs MAX_CHUNKS=4160 threads per + // level regardless of actual work. At L5 with only ~117 real chunks + // that's ~3979 pad-chunks each running a full fr_inv_by_a + S mont + // muls -- dominates wall time. 
+ const perLevelTChunks: number[] = []; + const perLevelTCarries: number[] = []; + { + let cur = new Uint32Array(initCounts); + for (let lv = 0; lv < LEVELS; lv++) { + let totalPairs = 0; + let totalCarries = 0; + const next = new Uint32Array(cur.length); + for (let b = 0; b < cur.length; b++) { + const n = cur[b]; + const p = (n / 2) | 0; + const c = n & 1; + totalPairs += p; + totalCarries += c; + next[b] = p + c; + } + perLevelTChunks.push(Math.max(1, Math.ceil(totalPairs / S))); + perLevelTCarries.push(Math.max(1, totalCarries)); + cur = next; + } + log('info', `host-simulated per-level T_chunks: ${perLevelTChunks.join(', ')}`); + log('info', `host-simulated per-level T_carries: ${perLevelTCarries.join(', ')}`); + } + + const mkStorage = (bytes: number, copyDst = true, copySrc = false): GPUBuffer => { + let usage = GPUBufferUsage.STORAGE; + if (copyDst) usage |= GPUBufferUsage.COPY_DST; + if (copySrc) usage |= GPUBufferUsage.COPY_SRC; + return device.createBuffer({ size: bytes, usage }); + }; + + // Ping-pong active_sums. + const bufA = makeSoABuf(device, M, true, true); + const bufB = makeSoABuf(device, M, true, true); + device.queue.writeBuffer(bufA, 0, initBuf); + device.queue.writeBuffer(bufB, 0, initBuf); // mirror initial for pad-pair availability + + // Plan buffers (reused per level). + const chunkPlanBuf = mkStorage(2 * MAX_PAIR_SLOTS * 4); + const scatterPlanBuf = mkStorage(MAX_PAIR_SLOTS * 4); + const carryPlanBuf = mkStorage(2 * MAX_CARRIES * 4); + + // Per-level counts/offsets buffers (ping-pong). 
+ const countsA = mkStorage(BUCKETS * 4); + const countsB = mkStorage(BUCKETS * 4); + const offsetsA = mkStorage((BUCKETS + 1) * 4); + const offsetsB = mkStorage((BUCKETS + 1) * 4); + device.queue.writeBuffer(countsA, 0, initCounts); + device.queue.writeBuffer(offsetsA, 0, initOffsets); + + // Totals atomic counter [pair_off_accum, carry_off_accum, new_off_accum, _] + const totalsBuf = device.createBuffer({ + size: 16, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST, + }); + + // Pre-padded chunk_plan / scatter_plan / carry_plan templates. + // Pad slots all point to safe values: + // chunk_plan pad pair = (M-2, M-1) — known distinct-x in active_sums + // scatter_plan pad dst = M-2 — discard target (within active_sums_new) + // carry_plan pad src = M-2, dst = M-2 — no-op self-copy of pad slot + const padChunkPlan = new Uint32Array(2 * MAX_PAIR_SLOTS); + const padScatterPlan = new Uint32Array(MAX_PAIR_SLOTS); + const padCarryPlan = new Uint32Array(2 * MAX_CARRIES); + for (let i = 0; i < MAX_PAIR_SLOTS; i++) { + padChunkPlan[2 * i + 0] = M - 2; + padChunkPlan[2 * i + 1] = M - 1; + padScatterPlan[i] = M - 2; + } + for (let i = 0; i < MAX_CARRIES; i++) { + padCarryPlan[2 * i + 0] = M - 2; + padCarryPlan[2 * i + 1] = M - 2; + } + + // Per-level params (reused). + const plannerParams = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + const fusedParams: GPUBuffer[] = []; + const carryParams: GPUBuffer[] = []; + for (let i = 0; i < LEVELS; i++) { + fusedParams.push(device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST })); + carryParams.push(device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST })); + } + device.queue.writeBuffer(plannerParams, 0, new Uint32Array([BUCKETS, S, 0, 0])); + + // Layouts. 
+ const plannerLayout = device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 5, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 6, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 7, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 8, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); + const fusedLayout = device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); + const carryLayout = device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); + + const plannerPipe = await compileOne(device, sm.gen_ba_planner_bench_shader(WGI, S, 64), `planner-W${WGI}-S${S}`, plannerLayout); + const fusedPipe = await compileOne(device, 
sm.gen_ba_fused_super_bench_shader(WGI, S), `fused-W${WGI}-S${S}`, fusedLayout); + const carryPipe = await compileOne(device, sm.gen_ba_carry_copy_bench_shader(WGI), `carry-W${WGI}`, carryLayout); + log('info', '3 pipelines compiled'); + + // Encode all level passes into one command encoder. + // Per level k: + // - clear totalsBuf to zero (writeBuffer queued before submit) + // - clear plan buffers to pad templates (writeBuffer queued) + // - dispatch planner + // - dispatch fused (numWgs = ceil(MAX_CHUNKS / WGI) — over-provisioned; + // idle threads early-out via if (t >= T) check) + // - dispatch carry (numWgs = ceil(MAX_CARRIES / WGI) — over-provisioned) + // - swap counts/offsets via bind group selection on next level + + const enc = device.createCommandEncoder(); + let curCountsIn: GPUBuffer = countsA; + let curCountsOut: GPUBuffer = countsB; + let curOffsetsIn: GPUBuffer = offsetsA; + let curOffsetsOut: GPUBuffer = offsetsB; + let curActiveIn: GPUBuffer = bufA; + let curActiveOut: GPUBuffer = bufB; + + const numWgsPlanner = Math.ceil(BUCKETS / WGI); + const numWgsFusedPerLevel = perLevelTChunks.map(t => Math.ceil(t / WGI)); + const numWgsCarryPerLevel = perLevelTCarries.map(t => Math.ceil(t / WGI)); + log('info', `numWgs: planner=${numWgsPlanner}, fused=${numWgsFusedPerLevel.join(',')}, carry=${numWgsCarryPerLevel.join(',')}`); + + // Per-level params with the right-sized T from the host bin-pack simulator. + for (let lv = 0; lv < LEVELS; lv++) { + device.queue.writeBuffer(fusedParams[lv], 0, new Uint32Array([perLevelTChunks[lv], M, M, 0])); + device.queue.writeBuffer(carryParams[lv], 0, new Uint32Array([perLevelTCarries[lv], M, M, 0])); + } + + // Pre-pad plan buffers (only need to do once; planner overwrites real + // entries each level, and the pre-pad is stable across levels because + // pad slots remain pad-valued). 
+ device.queue.writeBuffer(chunkPlanBuf, 0, padChunkPlan); + device.queue.writeBuffer(scatterPlanBuf, 0, padScatterPlan); + device.queue.writeBuffer(carryPlanBuf, 0, padCarryPlan); + + // Bind groups built per level (to swap counts/offsets/active buffers). + for (let lv = 0; lv < LEVELS; lv++) { + // Reset totals atomic counter before this level's planner. + device.queue.writeBuffer(totalsBuf, 0, new Uint32Array([0, 0, 0, 0])); + // Re-pad plan buffers (planner overwrites only real entries; the + // pad regions get re-padded between levels to clean any leftover + // real entries from the prior level). + device.queue.writeBuffer(chunkPlanBuf, 0, padChunkPlan); + device.queue.writeBuffer(scatterPlanBuf, 0, padScatterPlan); + device.queue.writeBuffer(carryPlanBuf, 0, padCarryPlan); + + const plannerBind = device.createBindGroup({ + layout: plannerLayout, + entries: [ + { binding: 0, resource: { buffer: curCountsIn } }, + { binding: 1, resource: { buffer: curOffsetsIn } }, + { binding: 2, resource: { buffer: chunkPlanBuf } }, + { binding: 3, resource: { buffer: scatterPlanBuf } }, + { binding: 4, resource: { buffer: carryPlanBuf } }, + { binding: 5, resource: { buffer: totalsBuf } }, + { binding: 6, resource: { buffer: curCountsOut } }, + { binding: 7, resource: { buffer: curOffsetsOut } }, + { binding: 8, resource: { buffer: plannerParams } }, + ], + }); + const fusedBind = device.createBindGroup({ + layout: fusedLayout, + entries: [ + { binding: 0, resource: { buffer: chunkPlanBuf } }, + { binding: 1, resource: { buffer: scatterPlanBuf } }, + { binding: 2, resource: { buffer: curActiveIn } }, + { binding: 3, resource: { buffer: curActiveOut } }, + { binding: 4, resource: { buffer: fusedParams[lv] } }, + ], + }); + const carryBind = device.createBindGroup({ + layout: carryLayout, + entries: [ + { binding: 0, resource: { buffer: carryPlanBuf } }, + { binding: 1, resource: { buffer: curActiveIn } }, + { binding: 2, resource: { buffer: curActiveOut } }, + { binding: 3, 
resource: { buffer: carryParams[lv] } }, + ], + }); + + // Encode the 3 passes for this level. + { + const pass = enc.beginComputePass(); + pass.setPipeline(plannerPipe); + pass.setBindGroup(0, plannerBind); + pass.dispatchWorkgroups(numWgsPlanner, 1, 1); + pass.end(); + } + { + const pass = enc.beginComputePass(); + pass.setPipeline(fusedPipe); + pass.setBindGroup(0, fusedBind); + pass.dispatchWorkgroups(numWgsFusedPerLevel[lv], 1, 1); + pass.end(); + } + { + const pass = enc.beginComputePass(); + pass.setPipeline(carryPipe); + pass.setBindGroup(0, carryBind); + pass.dispatchWorkgroups(numWgsCarryPerLevel[lv], 1, 1); + pass.end(); + } + + // Swap for next level. + [curCountsIn, curCountsOut] = [curCountsOut, curCountsIn]; + [curOffsetsIn, curOffsetsOut] = [curOffsetsOut, curOffsetsIn]; + [curActiveIn, curActiveOut] = [curActiveOut, curActiveIn]; + } + + // Single submit + single await for the entire bucket-accumulate. + const t0 = performance.now(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + const gpuWall = performance.now() - t0; + + // Sanity: at least one element of the final active_sums must be non-zero. + const sanity = await readNonZero(device, curActiveIn, 8); + const nsPerInpt = (gpuWall * 1e6) / NPTS; + + log( + sanity ? 'ok' : 'err', + `v3 pipeline: ${LEVELS} levels over-dispatched, single submit GPU wall=${gpuWall.toFixed(2)}ms, ns/in-pt=${nsPerInpt.toFixed(2)}, sanity=${sanity ? 
'OK' : 'FAIL'}`, + ); + + // Cleanup + bufA.destroy(); bufB.destroy(); + chunkPlanBuf.destroy(); scatterPlanBuf.destroy(); carryPlanBuf.destroy(); + countsA.destroy(); countsB.destroy(); offsetsA.destroy(); offsetsB.destroy(); + totalsBuf.destroy(); + plannerParams.destroy(); + for (const b of fusedParams) b.destroy(); + for (const b of carryParams) b.destroy(); + + return { + s: S, wgi: WGI, pairs: NPTS, buckets: BUCKETS, + levels_run: LEVELS, gpu_wall_ms: gpuWall, ns_per_inpt: nsPerInpt, sanity_ok: sanity, + }; +} + +function parseParams() { + const qp = new URLSearchParams(window.location.search); + if (qp.get('n')) NPTS = parseInt(qp.get('n')!, 10); + if (qp.get('buckets')) BUCKETS = parseInt(qp.get('buckets')!, 10); + if (qp.get('s')) S = parseInt(qp.get('s')!, 10); + if (qp.get('wgi')) WGI = parseInt(qp.get('wgi')!, 10); + if (qp.get('levels')) LEVELS = parseInt(qp.get('levels')!, 10); + return { n: NPTS, buckets: BUCKETS, s: S, wgi: WGI, levels: LEVELS }; +} + +async function main() { + try { + if (!('gpu' in navigator)) throw new Error('navigator.gpu missing'); + const params = parseParams(); + benchState.params = params; + log('info', `params: n=${params.n} buckets=${params.buckets} s=${params.s} wgi=${params.wgi} levels=${params.levels}`); + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + const p = BN254_BASE_FIELD; + const miscParams = compute_misc_params(p, 13); + const R = miscParams.r; + const sm = new ShaderManager(4, NPTS, BN254_CURVE_CONFIG, false); + const r = await runPipeline(device, sm, R, p); + benchState.results.push(r); + resultsClient.postProgress({ kind: 'pipeline_done', ns_per_inpt: r.ns_per_inpt, sanity_ok: r.sanity_ok }); + benchState.state = 'done'; + log('ok', 'done'); + } catch (e) { + const msg = e instanceof Error ? 
`${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main() + .catch(e => { const msg = e instanceof Error ? e.message : String(e); log('err', `unhandled: ${msg}`); benchState.state = 'error'; benchState.error = msg; }) + .finally(() => { postFinal().catch(() => {}); }); diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.html b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.html new file mode 100644 index 000000000000..f64e6d3ff15a --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.html @@ -0,0 +1,24 @@ + + + + + MSM bucket-accumulate multi-level pair-tree bench (WebGPU) + + + +

MSM bucket-accumulate multi-level pair-tree bench (WebGPU)

+

Query params: ?reps=R&n=N&buckets=B&s=S&wgi=W&mode=uniform|skewed&disp=D

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts new file mode 100644 index 000000000000..f8c05dfa79c3 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-msm-tree.ts @@ -0,0 +1,790 @@ +/// +// bench-msm-tree — end-to-end MSM bucket-accumulate benchmark for the +// multi-level pair-tree pipeline: +// 1. marshal-l0 (CSR + chunk_plan + point_pool -> strided chain_buf) +// 2. tree-disjoint level 0 (chain_buf -> level-1 input layout, in-place via ping-pong) +// 3. tree-disjoint level 1 (continues in ping-pong) +// ... +// level L-1 (final): tree-disjoint with `final` flag -> simple strided output +// 4. tail kernel (small buckets count<2*S -> one sum each) +// +// Reports per-stage and combined ns/in-pt for the full bucket-accumulate +// over N points distributed across B buckets. +// +// Modes: +// ?mode=uniform : every bucket has exactly 2*S = 32 points (clean +// multi-level test, no tail). +// ?mode=skewed : Poisson-distributed via uniform random scalar +// assignment. Main pair-tree handles buckets with +// count >= 2*S; tail kernel handles the rest. 
import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js';
import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js';
import { makeResultsClient } from './results_post.js';

// Number of 4-word (vec4) groups that make up one 8-word field element.
const PG = 2;
const DEFAULT_N = 1 << 17; // 131072 points
const DEFAULT_BUCKETS = 1 << 12; // 4096 buckets -> uniform avg 32 = 2*S
const DEFAULT_S = 16;
const DEFAULT_WGI = 64;
const DEFAULT_DISP = 4; // dispatch amortisation per timed sample
const DEFAULT_MODE = 'uniform' as const;

// Run-time tunables; they start at the defaults above and are reassigned
// elsewhere in this file.
let NPTS = DEFAULT_N;
let BUCKETS = DEFAULT_BUCKETS;
let S = DEFAULT_S;
let WGI = DEFAULT_WGI;
let DISP = DEFAULT_DISP;
let MODE: 'uniform' | 'skewed' = DEFAULT_MODE;

/**
 * Creates a deterministic 32-bit linear congruential generator
 * (state <- state * 1664525 + 1013904223 mod 2^32).
 *
 * @param seed Initial state; truncated to uint32, and a zero seed is
 *   replaced by 1.
 * @returns A function yielding the next uint32 state on every call.
 *   Only the HIGH bits of each draw are well distributed: the low k
 *   bits of a power-of-two-modulus LCG repeat with period 2^k.
 */
function makeRng(seed: number): () => number {
  let state = (seed >>> 0) || 1;
  return () => {
    state = (Math.imul(state, 1664525) + 1013904223) >>> 0;
    return state;
  };
}

/**
 * Draws a value in [1, p - 1] by rejection sampling.
 *
 * Each candidate is assembled from the top byte of successive `rng()`
 * draws. Taking the low byte instead (`rng() & 0xff`) would make the
 * byte stream repeat every 256 draws for the LCG from `makeRng`, which
 * collapses the output to a handful of distinct values.
 *
 * @param p Exclusive upper bound; must be greater than 1.
 * @param rng Source of uint32 values.
 * @returns A bigint v with 0 < v < p.
 * @throws RangeError if `p <= 1` (no valid value exists).
 */
function randomBelow(p: bigint, rng: () => number): bigint {
  if (p <= 1n) throw new RangeError(`randomBelow requires p > 1, got ${p}`);
  const bitlen = p.toString(2).length;
  const byteLen = Math.ceil(bitlen / 8);
  // Trim the candidate to bitlen bits so the acceptance rate stays above 50%.
  const mask = (1n << BigInt(bitlen)) - 1n;
  for (;;) {
    let v = 0n;
    for (let i = 0; i < byteLen; i++) v = (v << 8n) | BigInt(rng() >>> 24);
    v &= mask;
    if (v > 0n && v < p) return v;
  }
}

/**
 * Splits the low 256 bits of `v` into eight 32-bit words, least
 * significant word first.
 */
function bigintToPackedU32x8(v: bigint): Uint32Array {
  const w = new Uint32Array(8);
  let x = v;
  for (let i = 0; i < 8; i++) {
    w[i] = Number(x & 0xffffffffn);
    x >>= 32n;
  }
  return w;
}

/**
 * Returns the middle element of `xs` after a numeric sort (the upper of
 * the two middle elements when the length is even), or NaN for an empty
 * array. The input is not mutated.
 */
function median(xs: number[]): number {
  if (xs.length === 0) return NaN;
  const s = xs.slice().sort((a, b) => a - b);
  return s[Math.floor(s.length / 2)];
}

/**
 * Builds a pool of `poolSize` synthetic points as two coordinate planes
 * (x then y). Each coordinate is a random value in [1, p - 1] multiplied
 * by `R` mod `p`.
 *
 * Layout: the u32 offset of vec4 group `v` of element `e` in plane `pl`
 * is `((pl * PG + v) * poolSize + e) * 4`, with plane 0 = x, plane 1 = y.
 *
 * @param poolSize Number of elements per plane.
 * @param R Multiplier applied to every coordinate.
 * @param p Field modulus.
 * @param rng Source of uint32 values (see `randomBelow`).
 */
function buildPointPool(poolSize: number, R: bigint, p: bigint, rng: () => number): Uint32Array {
  const M = poolSize;
  const buf = new Uint32Array(2 * PG * M * 4);
  for (let e = 0; e < M; e++) {
    const x = (randomBelow(p, rng) * R) % p;
    const y = (randomBelow(p, rng) * R) % p;
    const wx = bigintToPackedU32x8(x);
    const wy = bigintToPackedU32x8(y);
    for (let v = 0; v < PG; v++) {
      const baseX = ((0 * PG + v) * M + e) * 4;
      const baseY = ((1 * PG + v) * M + e) * 4;
      buf[baseX + 0] = wx[4 * v + 0];
      buf[baseX + 1] = wx[4 * v + 1];
      buf[baseX + 2] = wx[4 * v + 2];
      buf[baseX + 3] = wx[4 * v + 3];
      buf[baseY + 0] = wy[4 * v + 0];
      buf[baseY + 1] = wy[4 * v + 1];
      buf[baseY + 2] = wy[4 * v + 2];
      buf[baseY + 3] = wy[4 * v + 3];
    }
  }
  return buf;
}

interface CSR {
  csrIndices: Uint32Array; // 1-based: index 0 reserved (decoy/unused)
  offsets: Uint32Array;
  counts: Uint32Array;
}

/**
 * Builds a CSR in which every one of `B` buckets holds exactly
 * `perBucket` consecutive (1-based) point indices.
 *
 * @throws Error if `perBucket` is not a non-negative integer, or if
 *   `N !== B * perBucket`. The integer check matters because a
 *   fractional quotient such as N / B can still satisfy the product
 *   test and would then be silently truncated by the Uint32Array stores.
 */
function buildUniformCSR(N: number, B: number, perBucket: number): CSR {
  if (!Number.isInteger(perBucket) || perBucket < 0) {
    throw new Error(
      `uniform mode requires an integer number of points per bucket, got ${perBucket} (N=${N}, B=${B})`,
    );
  }
  if (N !== B * perBucket) {
    throw new Error(`uniform mode requires N=${N} = B*${perBucket}=${B * perBucket}`);
  }
  const counts = new Uint32Array(B);
  const offsets = new Uint32Array(B + 1);
  const csrIndices = new Uint32Array(N);
  for (let b = 0; b < B; b++) {
    counts[b] = perBucket;
    offsets[b + 1] = offsets[b] + perBucket;
    for (let i = 0; i < perBucket; i++) {
      csrIndices[offsets[b] + i] = b * perBucket + i + 1; // 1-based
    }
  }
  return { csrIndices, offsets, counts };
}

/**
 * Builds a CSR by assigning each of `N` points to one of `B` buckets
 * using two `rng()` draws per point, then grouping the (1-based) point
 * indices by bucket in ascending point order.
 */
function buildSkewedCSR(N: number, B: number, rng: () => number): CSR {
  const bucket = new Uint32Array(N);
  const counts = new Uint32Array(B);
  // LCG low bits have short periods (low 12 bits cycle every 4096
  // calls, which equals B in the default config), so direct modulo
  // produces a degenerate "every bucket gets exactly N/B points"
  // distribution. Compose a uniform 32-bit integer from the high
  // halves of two RNG draws via unsigned arithmetic (multiplication
  // to avoid the signed-i32 quirk of JS bitwise ops) and reduce.
  for (let i = 0; i < N; i++) {
    const hi = (rng() >>> 16) & 0xffff;
    const lo = (rng() >>> 16) & 0xffff;
    const v = hi * 0x10000 + lo;
    const b = v % B;
    bucket[i] = b;
    counts[b]++;
  }
  // Exclusive prefix sum of counts gives each bucket's start offset.
  const offsets = new Uint32Array(B + 1);
  for (let b = 0; b < B; b++) offsets[b + 1] = offsets[b] + counts[b];
  // cursor[b] is the number of slots already filled in bucket b.
  const cursor = new Uint32Array(B);
  const csrIndices = new Uint32Array(N);
  for (let i = 0; i < N; i++) {
    const b = bucket[i];
    csrIndices[offsets[b] + cursor[b]++] = i + 1; // 1-based
  }
  return { csrIndices, offsets, counts };
}

// Split each bucket into (a) full 2*S chunks for the level-0 marshal +
// pair-tree pass and (b) a tail of count mod 2*S points for the tail
// kernel. Returns a chunk_plan for the main pipeline and a tail_plan
// for the tail kernel.
//
// Only buckets where count >= 2*S contribute to the main path. Each
// contributes floor(count / (2*S)) chunks of exactly 2*S points each.
// Remaining count mod 2*S points go to the tail. Buckets with
// count < 2*S go entirely to the tail. NOTE: for v1, the per-bucket
// "extra full-2S chunks already reduced via main path" is *not*
// re-folded into a single per-bucket sum; this would matter for
// correctness but for the bench we just measure dispatch wall-clock.
+function buildChunkAndTailPlans( + offsets: Uint32Array, + counts: Uint32Array, + S: number, +): { chunkPlan: Uint32Array; T: number; tailPlan: Uint32Array; TT: number; mainPoints: number; tailPoints: number } { + const B = counts.length; + const blkSize = 2 * S; + let T = 0; + let TT = 0; + let mainPts = 0; + let tailPts = 0; + for (let b = 0; b < B; b++) { + const c = counts[b]; + const nMain = Math.floor(c / blkSize); + T += nMain; + mainPts += nMain * blkSize; + const remain = c - nMain * blkSize; + if (remain > 0) { + TT++; + tailPts += remain; + } + } + const chunkPlan = new Uint32Array(2 * T); + const tailPlan = new Uint32Array(3 * TT); + let t = 0; + let tt = 0; + for (let b = 0; b < B; b++) { + const c = counts[b]; + const nMain = Math.floor(c / blkSize); + for (let k = 0; k < nMain; k++) { + chunkPlan[2 * t + 0] = b; + chunkPlan[2 * t + 1] = offsets[b] + k * blkSize; + t++; + } + const remain = c - nMain * blkSize; + if (remain > 0) { + tailPlan[3 * tt + 0] = b; + tailPlan[3 * tt + 1] = offsets[b] + nMain * blkSize; + tailPlan[3 * tt + 2] = remain; + tt++; + } + } + return { chunkPlan, T, tailPlan, TT, mainPoints: mainPts, tailPoints: tailPts }; +} + +interface KernelTiming { + median_ms: number; + min_ms: number; + max_ms: number; + samples_ms: number[]; +} + +async function compile( + device: GPUDevice, + code: string, + cacheKey: string, + layout: GPUBindGroupLayout, +): Promise { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + const errLines: string[] = []; + for (const m of info.messages) { + const line = `[shader ${cacheKey}] ${m.type}: ${m.message} (line ${m.lineNum}, col ${m.linePos})`; + if (m.type === 'error') { + console.error(line); + log('err', line); + errLines.push(line); + hasError = true; + } else { + console.warn(line); + } + } + if (hasError) throw new Error(`WGSL compile failed for ${cacheKey}: ${errLines.slice(0, 4).join(' | ')}`); + return 
device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); +} + +function marshalLayout(device: GPUDevice): GPUBindGroupLayout { + return device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); +} + +function treeKernelLayout(device: GPUDevice): GPUBindGroupLayout { + return device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); +} + +function tailLayout(device: GPUDevice): GPUBindGroupLayout { + return device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); +} + +async function readNonZero(device: GPUDevice, buf: GPUBuffer, u32Count: number): Promise { + const bytes = u32Count * 4; + const staging = device.createBuffer({ size: bytes, usage: 
GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }); + const enc = device.createCommandEncoder(); + enc.copyBufferToBuffer(buf, 0, staging, 0, bytes); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + await staging.mapAsync(GPUMapMode.READ); + const u32 = new Uint32Array(staging.getMappedRange().slice(0)); + staging.unmap(); + staging.destroy(); + for (let i = 0; i < u32.length; i++) if (u32[i] !== 0) return true; + return false; +} + +async function timeDispatch( + device: GPUDevice, + pipeline: GPUComputePipeline, + bind: GPUBindGroup, + numWgs: number, + reps: number, + passes: number, +): Promise { + // warmup + { + const enc = device.createCommandEncoder(); + for (let p = 0; p < passes; p++) { + const pass = enc.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroups(numWgs, 1, 1); + pass.end(); + } + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + } + const samples: number[] = []; + for (let r = 0; r < reps; r++) { + const enc = device.createCommandEncoder(); + for (let p = 0; p < passes; p++) { + const pass = enc.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroups(numWgs, 1, 1); + pass.end(); + } + const t0 = performance.now(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + samples.push(performance.now() - t0); + } + return { + median_ms: median(samples), + min_ms: Math.min(...samples), + max_ms: Math.max(...samples), + samples_ms: samples, + }; +} + +interface RunResult { + s: number; + wgi: number; + disp: number; + pairs: number; + buckets: number; + mode: string; + T_main: number; + T_tail: number; + main_points: number; + tail_points: number; + levels: number; + marshal_ms: number; + level_ms: number[]; + tail_ms: number; + total_ms: number; + marshal_ns_per_inpt: number; + pair_tree_ns_per_inpt: number; + tail_ns_per_inpt: number; + 
combined_ns_per_inpt: number; + sanity_ok: boolean; +} + +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: { reps: number; n: number; buckets: number; s: number; wgi: number; disp: number; mode: string } | null; + results: RunResult[]; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { + state: 'boot', + params: null, + results: [], + error: null, + log: [], +}; +(window as unknown as { __bench: BenchState }).__bench = benchState; + +const resultsClient = makeResultsClient({ page: 'bench-msm-tree' }); +(window as unknown as { __runId: string }).__runId = resultsClient.runId; + +async function postFinal(): Promise { + await resultsClient.postResults({ + state: benchState.state, + params: benchState.params, + results: benchState.results, + error: benchState.error, + log: benchState.log, + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); +} + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-msm-tree] ${msg}`); +} + +async function runPipeline( + device: GPUDevice, + sm: ShaderManager, + reps: number, + R: bigint, + p: bigint, +): Promise { + log('info', `=== mode=${MODE} N=${NPTS} B=${BUCKETS} S=${S} WGI=${WGI} DISP=${DISP}`); + + const rng = makeRng(0x4711); + const poolSize = NPTS + 1; // index 0 reserved (1-based) + const poolU32 = buildPointPool(poolSize, R, p, rng); + + const csr: CSR = + MODE === 'uniform' + ? 
buildUniformCSR(NPTS, BUCKETS, NPTS / BUCKETS) + : buildSkewedCSR(NPTS, BUCKETS, rng); + + const offsets = csr.offsets; + const counts = csr.counts; + const { chunkPlan, T, tailPlan, TT, mainPoints, tailPoints } = buildChunkAndTailPlans(offsets, counts, S); + log( + 'info', + `plan: main T=${T} chunks (${mainPoints} pts) | tail TT=${TT} threads (${tailPoints} pts) | dropped=${NPTS - mainPoints - tailPoints}`, + ); + if (T === 0 && TT === 0) throw new Error('plan is empty'); + + // Determine number of pair-tree levels. Each level halves T and + // every level produces T*S outputs. We stop when T*S equals the + // distinct-bucket count contributing to the main path (= one sum + // per bucket). Iterating further would start pairing across + // buckets, which is incorrect. + let bMain = 0; + for (let b = 0; b < counts.length; b++) { + if (counts[b] >= 2 * S) bMain++; + } + if (bMain === 0) throw new Error('no buckets in main path'); + const stopT = Math.max(1, Math.ceil(bMain / S)); + const levels: number[] = []; + for (let t = T; t >= stopT; t = Math.floor(t / 2)) { + levels.push(t); + if (t === stopT) break; + if (levels.length > 24) throw new Error('too many tree levels'); + } + log( + 'info', + `pair-tree levels: ${levels.length} (T sequence: ${levels.join(' -> ')}, stopT=${stopT}, bMain=${bMain})`, + ); + + // Buffers. 
+ const mkSb = (size: number, copyDst: boolean, copySrc: boolean): GPUBuffer => { + let usage = GPUBufferUsage.STORAGE; + if (copyDst) usage |= GPUBufferUsage.COPY_DST; + if (copySrc) usage |= GPUBufferUsage.COPY_SRC; + return device.createBuffer({ size, usage }); + }; + + const poolBuf = mkSb(poolU32.byteLength, true, false); + device.queue.writeBuffer(poolBuf, 0, poolU32); + let csrBuf: GPUBuffer | null = null; + let chunkBuf: GPUBuffer | null = null; + let tailPlanBuf: GPUBuffer | null = null; + if (T > 0) { + csrBuf = mkSb(csr.csrIndices.byteLength, true, false); + device.queue.writeBuffer(csrBuf, 0, csr.csrIndices); + chunkBuf = mkSb(chunkPlan.byteLength, true, false); + device.queue.writeBuffer(chunkBuf, 0, chunkPlan); + } else if (TT > 0) { + csrBuf = mkSb(csr.csrIndices.byteLength, true, false); + device.queue.writeBuffer(csrBuf, 0, csr.csrIndices); + } + if (TT > 0) { + tailPlanBuf = mkSb(tailPlan.byteLength, true, false); + device.queue.writeBuffer(tailPlanBuf, 0, tailPlan); + } + + // Ping-pong buffers for the pair-tree. Each plane needs at most + // 2*S*T_0 vec4 (level-0 input size). Output of level k is sized + // S*T_k vec4 per plane, which is half. Use two buffers of the same + // size and ping-pong. + let bufA: GPUBuffer | null = null; + let bufB: GPUBuffer | null = null; + if (T > 0) { + const planeBytes = 2 * PG * (2 * S * T) * 4 * 4; // 2 planes (P.x, P.y) * PG vec4 * (2*S*T elems) * 4 u32/vec4 * 4 B + bufA = mkSb(planeBytes, false, true); + bufB = mkSb(planeBytes, false, true); + } + + // Bucket-sums buffer (tail output): 2 planes (P.x, P.y), PG vec4 per + // bucket. Pre-zero implicitly (GPU buffers start zeroed in WebGPU). 
+ const bucketSumsBytes = 2 * PG * BUCKETS * 4 * 4; + const bucketSumsBuf = mkSb(bucketSumsBytes, false, true); + + const paramsBytes = 16; + const marshalParams = device.createBuffer({ size: paramsBytes, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + const levelParams: GPUBuffer[] = []; + for (let i = 0; i < levels.length; i++) { + levelParams.push(device.createBuffer({ size: paramsBytes, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST })); + } + const tailParams = device.createBuffer({ size: paramsBytes, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + + // Compile + bind. + let marshalPipeline: GPUComputePipeline | null = null; + let marshalBind: GPUBindGroup | null = null; + let marshalWgs = 0; + if (T > 0 && bufA !== null && csrBuf !== null && chunkBuf !== null) { + const code = sm.gen_ba_marshal_tree_l0_bench_shader(WGI, S); + log('info', `marshal-l0 shader: ${code.length} chars`); + const mL = marshalLayout(device); + marshalPipeline = await compile(device, code, `marshal-l0-W${WGI}-S${S}`, mL); + marshalBind = device.createBindGroup({ + layout: mL, + entries: [ + { binding: 0, resource: { buffer: csrBuf } }, + { binding: 1, resource: { buffer: chunkBuf } }, + { binding: 2, resource: { buffer: poolBuf } }, + { binding: 3, resource: { buffer: bufA } }, + { binding: 4, resource: { buffer: marshalParams } }, + ], + }); + device.queue.writeBuffer(marshalParams, 0, new Uint32Array([T, poolSize, 0, 0])); + marshalWgs = Math.ceil(T / WGI); + } + + // Tree kernel: compile once, bind per-level with the appropriate + // (input, output, params) trio. 
+ let treePipeline: GPUComputePipeline | null = null; + const treeBinds: GPUBindGroup[] = []; + const treeNumWgs: number[] = []; + if (T > 0 && bufA !== null && bufB !== null) { + const code = sm.gen_ba_pair_disjoint_tree_bench_shader(WGI, S); + log('info', `tree-disjoint shader: ${code.length} chars`); + const tL = treeKernelLayout(device); + treePipeline = await compile(device, code, `tree-disjoint-W${WGI}-S${S}`, tL); + const dummy = device.createBuffer({ size: 16, usage: GPUBufferUsage.STORAGE }); + let curIn = bufA; + let curOut = bufB; + for (let lv = 0; lv < levels.length; lv++) { + const T_lv = levels[lv]; + const N_in_lv = 2 * S * T_lv; + const isFinal = lv === levels.length - 1 ? 1 : 0; + device.queue.writeBuffer(levelParams[lv], 0, new Uint32Array([N_in_lv, T_lv, isFinal, 0])); + treeBinds.push( + device.createBindGroup({ + layout: tL, + entries: [ + { binding: 0, resource: { buffer: curIn } }, + { binding: 1, resource: { buffer: dummy } }, + { binding: 2, resource: { buffer: curOut } }, + { binding: 3, resource: { buffer: levelParams[lv] } }, + ], + }), + ); + treeNumWgs.push(Math.ceil(T_lv / WGI)); + const swap = curIn; + curIn = curOut; + curOut = swap; + } + } + + // Tail pipeline. 
+ let tailPipeline: GPUComputePipeline | null = null; + let tailBind: GPUBindGroup | null = null; + let tailWgs = 0; + if (TT > 0 && csrBuf !== null && tailPlanBuf !== null) { + const code = sm.gen_ba_tail_reduce_bench_shader(WGI, S); + log('info', `tail shader: ${code.length} chars`); + const tL = tailLayout(device); + tailPipeline = await compile(device, code, `tail-W${WGI}-S${S}`, tL); + tailBind = device.createBindGroup({ + layout: tL, + entries: [ + { binding: 0, resource: { buffer: csrBuf } }, + { binding: 1, resource: { buffer: tailPlanBuf } }, + { binding: 2, resource: { buffer: poolBuf } }, + { binding: 3, resource: { buffer: bucketSumsBuf } }, + { binding: 4, resource: { buffer: tailParams } }, + ], + }); + device.queue.writeBuffer(tailParams, 0, new Uint32Array([TT, poolSize, BUCKETS, 0])); + tailWgs = Math.ceil(TT / WGI); + } + + // Warmup once. + if (marshalPipeline && marshalBind) { + const enc = device.createCommandEncoder(); + const pass = enc.beginComputePass(); + pass.setPipeline(marshalPipeline); + pass.setBindGroup(0, marshalBind); + pass.dispatchWorkgroups(marshalWgs, 1, 1); + pass.end(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + } + if (treePipeline) { + for (let lv = 0; lv < treeBinds.length; lv++) { + const enc = device.createCommandEncoder(); + const pass = enc.beginComputePass(); + pass.setPipeline(treePipeline); + pass.setBindGroup(0, treeBinds[lv]); + pass.dispatchWorkgroups(treeNumWgs[lv], 1, 1); + pass.end(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + } + } + if (tailPipeline && tailBind) { + const enc = device.createCommandEncoder(); + const pass = enc.beginComputePass(); + pass.setPipeline(tailPipeline); + pass.setBindGroup(0, tailBind); + pass.dispatchWorkgroups(tailWgs, 1, 1); + pass.end(); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + } + + // Time each stage separately, DISP back-to-back per sample. 
+ let marshalTiming: KernelTiming | null = null; + if (marshalPipeline && marshalBind) { + marshalTiming = await timeDispatch(device, marshalPipeline, marshalBind, marshalWgs, reps, DISP); + } + const levelTimings: KernelTiming[] = []; + if (treePipeline) { + for (let lv = 0; lv < treeBinds.length; lv++) { + const t = await timeDispatch(device, treePipeline, treeBinds[lv], treeNumWgs[lv], reps, DISP); + levelTimings.push(t); + } + } + let tailTiming: KernelTiming | null = null; + if (tailPipeline && tailBind) { + tailTiming = await timeDispatch(device, tailPipeline, tailBind, tailWgs, reps, DISP); + } + + // Sanity: at least one of the output buffers must have nonzero data. + let sanity = false; + if (treePipeline && bufA && bufB) { + const finalBuf = treeBinds.length % 2 === 0 ? bufA : bufB; + sanity = sanity || (await readNonZero(device, finalBuf, 8)); + } + if (tailPipeline) { + sanity = sanity || (await readNonZero(device, bucketSumsBuf, 8)); + } + + // Compute per-stage ns/in-pt (normalised to total points fed to that + // stage; DISP-amortised wall clock). + const totalInPts = mainPoints + tailPoints; + const marshalMed = marshalTiming?.median_ms ?? 0; + const marshalNs = (marshalMed * 1e6) / (mainPoints * DISP); + const treeTotalMed = levelTimings.reduce((acc, t) => acc + t.median_ms, 0); + const treeNs = (treeTotalMed * 1e6) / (mainPoints * DISP); + const tailMed = tailTiming?.median_ms ?? 0; + const tailNs = TT > 0 ? (tailMed * 1e6) / (tailPoints * DISP) : 0; + const combinedTotal = marshalMed + treeTotalMed + tailMed; + const combinedNs = (combinedTotal * 1e6) / (totalInPts * DISP); + + log( + sanity ? 'ok' : 'err', + `marshal=${marshalNs.toFixed(2)}ns/pt pair_tree=${treeNs.toFixed(2)}ns/pt tail=${tailNs.toFixed(2)}ns/pt | combined=${combinedNs.toFixed(2)}ns/in-pt | sanity=${sanity ? 
'OK' : 'FAIL'}`, + ); + for (let lv = 0; lv < levelTimings.length; lv++) { + log('info', ` level ${lv} T=${levels[lv]}: median=${levelTimings[lv].median_ms.toFixed(3)}ms min=${levelTimings[lv].min_ms.toFixed(3)}ms`); + } + + // Cleanup + poolBuf.destroy(); + csrBuf?.destroy(); + chunkBuf?.destroy(); + tailPlanBuf?.destroy(); + bufA?.destroy(); + bufB?.destroy(); + bucketSumsBuf.destroy(); + marshalParams.destroy(); + for (const b of levelParams) b.destroy(); + tailParams.destroy(); + + return { + s: S, + wgi: WGI, + disp: DISP, + pairs: NPTS, + buckets: BUCKETS, + mode: MODE, + T_main: T, + T_tail: TT, + main_points: mainPoints, + tail_points: tailPoints, + levels: levelTimings.length, + marshal_ms: marshalMed, + level_ms: levelTimings.map(t => t.median_ms), + tail_ms: tailMed, + total_ms: combinedTotal, + marshal_ns_per_inpt: marshalNs, + pair_tree_ns_per_inpt: treeNs, + tail_ns_per_inpt: tailNs, + combined_ns_per_inpt: combinedNs, + sanity_ok: sanity, + }; +} + +function parseParams() { + const qp = new URLSearchParams(window.location.search); + const reps = parseInt(qp.get('reps') ?? 
'5', 10); + if (!Number.isFinite(reps) || reps <= 0 || reps > 50) throw new Error(`?reps must be in (0, 50]`); + if (qp.get('n')) { + const v = parseInt(qp.get('n')!, 10); + if (!Number.isFinite(v) || v <= 0 || v > (1 << 20)) throw new Error(`?n must be in (0, 2^20]`); + NPTS = v; + } + if (qp.get('buckets')) { + const v = parseInt(qp.get('buckets')!, 10); + if (!Number.isFinite(v) || v <= 0 || v > (1 << 18)) throw new Error(`?buckets must be in (0, 2^18]`); + BUCKETS = v; + } + if (qp.get('s')) { + const v = parseInt(qp.get('s')!, 10); + if (!Number.isFinite(v) || v <= 0 || v > 256) throw new Error(`?s must be in (0, 256]`); + S = v; + } + if (qp.get('wgi')) { + const v = parseInt(qp.get('wgi')!, 10); + if (!Number.isFinite(v) || v <= 0 || v > 1024) throw new Error(`?wgi must be in (0, 1024]`); + WGI = v; + } + if (qp.get('disp')) { + const v = parseInt(qp.get('disp')!, 10); + if (!Number.isFinite(v) || v <= 0 || v > 64) throw new Error(`?disp must be in (0, 64]`); + DISP = v; + } + if (qp.get('mode')) { + const v = qp.get('mode')!; + if (v !== 'uniform' && v !== 'skewed') throw new Error(`?mode must be uniform or skewed`); + MODE = v; + } + return { reps, n: NPTS, buckets: BUCKETS, s: S, wgi: WGI, disp: DISP, mode: MODE }; +} + +async function main() { + try { + if (!('gpu' in navigator)) throw new Error('navigator.gpu missing — WebGPU not available'); + const params = parseParams(); + benchState.params = params; + log( + 'info', + `params: reps=${params.reps} n=${params.n} buckets=${params.buckets} s=${params.s} wgi=${params.wgi} disp=${params.disp} mode=${params.mode}`, + ); + + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + + const p = BN254_BASE_FIELD; + const miscParams = compute_misc_params(p, 13); + const R = miscParams.r; + + const sm = new ShaderManager(4, NPTS, BN254_CURVE_CONFIG, false); + + const r = await runPipeline(device, sm, params.reps, R, p); + benchState.results.push(r); + 
resultsClient.postProgress({ + kind: 'pipeline_done', + mode: r.mode, + combined_ns_per_inpt: r.combined_ns_per_inpt, + sanity_ok: r.sanity_ok, + }); + + benchState.state = 'done'; + log('ok', 'pipeline complete'); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main() + .catch(e => { + const msg = e instanceof Error ? e.message : String(e); + log('err', `unhandled: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + }) + .finally(() => { + postFinal().catch(() => {}); + }); diff --git a/barretenberg/ts/dev/msm-webgpu/bench-planner.html b/barretenberg/ts/dev/msm-webgpu/bench-planner.html new file mode 100644 index 000000000000..dd773cd4d504 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-planner.html @@ -0,0 +1,22 @@ + + + + + Standalone GPU planner microbench (WebGPU) + + + +

Standalone GPU planner microbench

+

Query params: ?buckets=B&lambda=L&s=S&tpb=T&per=P&disp=D&reps=R&validate=1

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-planner.ts b/barretenberg/ts/dev/msm-webgpu/bench-planner.ts new file mode 100644 index 000000000000..210da3c73b83 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-planner.ts @@ -0,0 +1,423 @@ +/// +// Standalone microbench for the GPU bin-packing planner kernel. +// Isolates the planner from the rest of the MSM pipeline so we can +// pin down the minimum time required to build a per-level plan +// (chunk_plan + scatter_plan + carry_plan + new_counts/offsets) on +// the GPU. +// +// Inputs (synthetic, host-built upfront): +// counts[B] per-bucket active count drawn from Poisson(lambda). +// offsets[B+1] prefix sum of counts. +// +// Each planner dispatch is one workgroup of TPB threads. Each thread +// handles PER_THREAD buckets. B = TPB * PER_THREAD. +// +// Timing methodology: +// - Compile pipeline. +// - Warmup (1 dispatch). +// - For each rep: +// Encode DISP back-to-back dispatches in ONE command encoder. +// performance.now() right before submit and after await. +// - Per-planner time = sample / DISP. +// - Report min, median, max across reps. +// +// Validation (?validate=1): +// - Run 1 dispatch. +// - Read back chunk_plan, scatter_plan, carry_plan, new_counts, +// new_offsets, totals. +// - Cross-check against a host-side bin-pack reference. 
import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
import { makeResultsClient } from './results_post.js';

// Bench configuration. These are the defaults; parseParams() overrides them
// from the URL query string before the run starts.
let BUCKETS = 4096;
let LAMBDA = 32; // mean per-bucket count
let S = 16;
let TPB = 256;
let PER_THREAD = 16; // BUCKETS / TPB
let DISP = 128;
let REPS = 5;
let VALIDATE = false;

// Deterministic 32-bit linear congruential generator
// (state = state * 1664525 + 1013904223 mod 2^32). A seed of 0 is replaced
// by 1. Each call advances the state and returns it as an unsigned 32-bit
// integer.
function makeRng(seed: number): () => number {
  let state = (seed >>> 0) || 1;
  return () => {
    state = (Math.imul(state, 1664525) + 1013904223) >>> 0;
    return state;
  };
}

// Approximate Poisson(λ) sample via the Knuth method. Adequate for
// generating synthetic bucket-count distributions in [0, ~3λ].
// The loop is hard-capped at 201 iterations, so the result never exceeds
// 200 regardless of lambda.
function poisson(lambda: number, rng: () => number): number {
  const L = Math.exp(-lambda);
  let k = 0;
  let p = 1.0;
  for (;;) {
    k++;
    // Map the 32-bit RNG output to a uniform value in [0, 1).
    const u = (rng() >>> 0) / 0x100000000;
    p *= u;
    if (p <= L) break;
    if (k > 200) break;
  }
  return k - 1;
}

// Builds the synthetic planner input: counts[b] is a poisson(lambda) draw
// from an RNG seeded with `seed`, and offsets (length B + 1) is the
// exclusive prefix sum of counts, with offsets[B] holding the grand total.
function buildSyntheticCounts(B: number, lambda: number, seed: number): { counts: Uint32Array; offsets: Uint32Array } {
  const rng = makeRng(seed);
  const counts = new Uint32Array(B);
  for (let b = 0; b < B; b++) counts[b] = poisson(lambda, rng);
  const offsets = new Uint32Array(B + 1);
  for (let b = 0; b < B; b++) offsets[b + 1] = offsets[b] + counts[b];
  return { counts, offsets };
}

// Host-side bin-pack reference. Returns the EXACT same outputs the GPU
// planner is expected to produce (modulo per-bucket atomic-ordering
// differences, which the v2 planner avoids — its order matches host).
//
// For a bucket with n entries: pc = floor(n / 2) pairs and cf = n & 1
// carries, so new_counts[b] = pc + cf. Pairs are numbered globally in bucket
// order and packed S per chunk:
//   chunkPlan[2 * slot + {0, 1}] = indices of the two pair members
//   scatterPlan[slot]            = destination index (newOffsets[b] + j)
//   carryPlan[2 * c + {0, 1}]    = (source index, destination index)
// chunkPlan/scatterPlan are sized for at least one chunk and carryPlan for
// at least one carry, so no returned array is ever zero-length.
function buildHostReference(counts: Uint32Array, offsets: Uint32Array, S: number) {
  const B = counts.length;
  let totalPairs = 0;
  let totalCarries = 0;
  let totalNew = 0;
  const newCounts = new Uint32Array(B);
  const newOffsets = new Uint32Array(B + 1);
  // First pass: compute new_counts and accumulate totals.
  for (let b = 0; b < B; b++) {
    const n = counts[b];
    const pc = Math.floor(n / 2);
    const cf = n & 1;
    newCounts[b] = pc + cf;
    totalPairs += pc;
    totalCarries += cf;
    totalNew += pc + cf;
  }
  for (let b = 0; b < B; b++) newOffsets[b + 1] = newOffsets[b] + newCounts[b];
  const numChunks = Math.max(1, Math.ceil(totalPairs / S));
  const chunkPlan = new Uint32Array(2 * numChunks * S);
  const scatterPlan = new Uint32Array(numChunks * S);
  const carryPlan = new Uint32Array(2 * Math.max(1, totalCarries));
  // Second pass: emit the pair / scatter / carry plans in bucket order.
  let pairOff = 0;
  let carryOff = 0;
  for (let b = 0; b < B; b++) {
    const n = counts[b];
    const pc = Math.floor(n / 2);
    const cf = n & 1;
    const bucketBase = offsets[b];
    for (let j = 0; j < pc; j++) {
      const slot = pairOff + j;
      const chunkId = Math.floor(slot / S);
      const slotInChunk = slot % S;
      const cpBase = 2 * (chunkId * S + slotInChunk);
      chunkPlan[cpBase + 0] = bucketBase + 2 * j;
      chunkPlan[cpBase + 1] = bucketBase + 2 * j + 1;
      scatterPlan[chunkId * S + slotInChunk] = newOffsets[b] + j;
    }
    if (cf) {
      // The odd element is the bucket's last entry; it lands right after
      // the bucket's pair sums in the next level.
      carryPlan[2 * carryOff + 0] = bucketBase + n - 1;
      carryPlan[2 * carryOff + 1] = newOffsets[b] + pc;
      carryOff++;
    }
    pairOff += pc;
  }
  return { chunkPlan, scatterPlan, carryPlan, newCounts, newOffsets, totalPairs, totalCarries, totalNew, numChunks };
}

// Returns the element at index floor(length / 2) of a sorted copy of xs
// (the upper middle element for even lengths); NaN for an empty array.
function median(xs: number[]): number {
  if (xs.length === 0) return NaN;
  const s = xs.slice().sort((a, b) => a - b);
  return s[Math.floor(s.length / 2)];
}

// One benchmark run as reported to the results client.
interface BenchResult {
  buckets: number;
  lambda: number;
  s: number;
  tpb: number;
  per_thread: number;
  disp: number;
  reps: number;
  total_pairs: number;
  total_carries: number;
  num_chunks: number;
  per_dispatch_us: { min: number; median: number; max: number };
  wall_samples_ms: number[];
  validated: boolean;
}

// Page-level state, exposed on window.__bench and sent by postFinal().
interface BenchState {
  state: 'boot' | 'running' | 'done' | 'error';
  params: Record<string, unknown> | null;
  results: BenchResult[];
  error: string | null;
  log: string[];
}

const benchState: BenchState = { state: 'boot', params: null, results: [], error: null, log: [] };
(window as unknown as { __bench: BenchState }).__bench = benchState;
const resultsClient = makeResultsClient({ page: 'bench-planner' });
(window as unknown as { __runId: string }).__runId = resultsClient.runId;

// Sends the final bench state (params, results, error, log) plus basic
// browser information through the results client.
async function postFinal(): Promise<void> {
  await resultsClient.postResults({
    state: benchState.state, params: benchState.params, results: benchState.results,
    error: benchState.error, log: benchState.log,
    userAgent: navigator.userAgent, hardwareConcurrency: navigator.hardwareConcurrency,
  });
}

const $log = document.getElementById('log') as HTMLDivElement;

// Appends msg to the on-page log element (CSS class derived from level),
// records it in benchState.log and mirrors it to the console.
function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) {
  const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : '';
  const span = document.createElement('div');
  span.className = cls;
  span.textContent = msg;
  $log.appendChild(span);
  benchState.log.push(`[${level}] ${msg}`);
  console.log(`[bench-planner] ${msg}`);
}

// Compiles `code` into a compute pipeline with entry point 'main' and a
// single bind group layout. Compilation messages are printed; if any has
// type 'error' this throws an Error carrying up to four of them.
async function compileOne(device: GPUDevice, code: string, key: string, layout: GPUBindGroupLayout): Promise<GPUComputePipeline> {
  const module = device.createShaderModule({ code });
  const info = await module.getCompilationInfo();
  let hasError = false;
  const errLines: string[] = [];
  for (const m of info.messages) {
    const line = `[shader ${key}] ${m.type}: ${m.message} (line ${m.lineNum}, col ${m.linePos})`;
    if (m.type === 'error') { console.error(line); log('err', line); errLines.push(line); hasError = true; }
    else { console.warn(line); }
  }
  if (hasError) throw new Error(`WGSL compile failed for ${key}: ${errLines.slice(0, 4).join(' | ')}`);
  return device.createComputePipelineAsync({
    layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }),
    compute: { module, entryPoint: 'main' },
  });
}

// Copies the first byteLength bytes of buf into a temporary MAP_READ staging
// buffer and returns them as a detached Uint32Array copy. The staging buffer
// is unmapped and destroyed before returning.
async function readbackU32(device: GPUDevice, buf: GPUBuffer, byteLength: number): Promise<Uint32Array> {
  const staging = device.createBuffer({ size: byteLength, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST });
  const enc = device.createCommandEncoder();
  enc.copyBufferToBuffer(buf, 0, staging, 0, byteLength);
  device.queue.submit([enc.finish()]);
  await device.queue.onSubmittedWorkDone();
  await staging.mapAsync(GPUMapMode.READ);
  // slice(0) copies the mapped range so the data survives unmap().
  const out = new Uint32Array(staging.getMappedRange().slice(0));
  staging.unmap();
  staging.destroy();
  return out;
}

// Runs the planner microbench once with the current module-level settings:
// builds synthetic counts and the host reference, compiles the planner
// shader, does one warmup dispatch, optionally validates the warmup output
// against the host reference, then times REPS submissions of DISP
// back-to-back dispatches each.
//
// Throws if BUCKETS !== TPB * PER_THREAD or if shader compilation fails.
// Every GPU buffer allocated here is destroyed on both the success and the
// error path.
async function runOne(device: GPUDevice, sm: ShaderManager): Promise<BenchResult> {
  log('info', `=== B=${BUCKETS} λ=${LAMBDA} S=${S} TPB=${TPB} PER=${PER_THREAD} DISP=${DISP} REPS=${REPS}`);
  if (TPB * PER_THREAD !== BUCKETS) throw new Error(`BUCKETS=${BUCKETS} must equal TPB*PER_THREAD=${TPB * PER_THREAD}`);

  const { counts, offsets } = buildSyntheticCounts(BUCKETS, LAMBDA, 0x5fa11);
  let totalActive = 0;
  let cMin = 99999, cMax = 0;
  for (let b = 0; b < BUCKETS; b++) {
    totalActive += counts[b];
    if (counts[b] > cMax) cMax = counts[b];
    if (counts[b] < cMin) cMin = counts[b];
  }
  log('info', `synthetic counts: min=${cMin} max=${cMax} totalActive=${totalActive}`);

  const ref = buildHostReference(counts, offsets, S);
  log('info', `host reference: totalPairs=${ref.totalPairs} totalCarries=${ref.totalCarries} numChunks=${ref.numChunks}`);

  // Every buffer created below is tracked so the `finally` block can release
  // it even when compilation or a readback throws.
  const owned: GPUBuffer[] = [];

  // Allocate output buffers sized for the host-computed plan.
  // For a real MSM these sizes would be conservative max bounds; here
  // we use exact host values for tighter validation.
  const mkStorage = (bytes: number, copySrc = false, copyDst = false): GPUBuffer => {
    let usage = GPUBufferUsage.STORAGE;
    if (copySrc) usage |= GPUBufferUsage.COPY_SRC;
    if (copyDst) usage |= GPUBufferUsage.COPY_DST;
    const buf = device.createBuffer({ size: bytes, usage });
    owned.push(buf);
    return buf;
  };

  try {
    const countsBuf = mkStorage(counts.byteLength, false, true);
    const offsetsBuf = mkStorage(offsets.byteLength, false, true);
    device.queue.writeBuffer(countsBuf, 0, counts);
    device.queue.writeBuffer(offsetsBuf, 0, offsets);

    const chunkPlanBytes = ref.chunkPlan.byteLength;
    const scatterPlanBytes = ref.scatterPlan.byteLength;
    const carryPlanBytes = ref.carryPlan.byteLength;
    const newCountsBytes = ref.newCounts.byteLength;
    const newOffsetsBytes = ref.newOffsets.byteLength;
    const totalsBytes = 16;

    const chunkPlanBuf = mkStorage(chunkPlanBytes, true);
    const scatterPlanBuf = mkStorage(scatterPlanBytes, true);
    const carryPlanBuf = mkStorage(carryPlanBytes, true);
    const newCountsBuf = mkStorage(newCountsBytes, true);
    const newOffsetsBuf = mkStorage(newOffsetsBytes, true);
    const totalsBuf = mkStorage(totalsBytes, true);

    const paramsBuf = device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST });
    owned.push(paramsBuf);
    device.queue.writeBuffer(paramsBuf, 0, new Uint32Array([BUCKETS, S, 0, 0]));

    // Bindings 0-1 are the read-only inputs (counts, offsets), 2-7 the
    // planner outputs, 8 the uniform params.
    const layout = device.createBindGroupLayout({
      entries: [
        { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
        { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
        { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
        { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
        { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
        { binding: 5, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
        { binding: 6, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
        { binding: 7, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
        { binding: 8, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } },
      ],
    });
    const pipeline = await compileOne(device, sm.gen_ba_planner_v2_bench_shader(TPB, PER_THREAD, S, 64), `planner-v2-T${TPB}-P${PER_THREAD}-S${S}`, layout);
    const bind = device.createBindGroup({
      layout,
      entries: [
        { binding: 0, resource: { buffer: countsBuf } },
        { binding: 1, resource: { buffer: offsetsBuf } },
        { binding: 2, resource: { buffer: chunkPlanBuf } },
        { binding: 3, resource: { buffer: scatterPlanBuf } },
        { binding: 4, resource: { buffer: carryPlanBuf } },
        { binding: 5, resource: { buffer: newCountsBuf } },
        { binding: 6, resource: { buffer: newOffsetsBuf } },
        { binding: 7, resource: { buffer: totalsBuf } },
        { binding: 8, resource: { buffer: paramsBuf } },
      ],
    });

    // Warmup.
    {
      const enc = device.createCommandEncoder();
      const pass = enc.beginComputePass();
      pass.setPipeline(pipeline);
      pass.setBindGroup(0, bind);
      pass.dispatchWorkgroups(1, 1, 1);
      pass.end();
      device.queue.submit([enc.finish()]);
      await device.queue.onSubmittedWorkDone();
    }
    log('info', 'warmup done');

    // Validation (one dispatch, read back, compare against host reference).
    // The buffers read here hold the output of the warmup dispatch above.
    let validated = false;
    if (VALIDATE) {
      const gpuTotals = await readbackU32(device, totalsBuf, totalsBytes);
      const gpuNewCounts = await readbackU32(device, newCountsBuf, newCountsBytes);
      const gpuNewOffsets = await readbackU32(device, newOffsetsBuf, newOffsetsBytes);
      const gpuChunkPlan = await readbackU32(device, chunkPlanBuf, chunkPlanBytes);
      const gpuScatterPlan = await readbackU32(device, scatterPlanBuf, scatterPlanBytes);
      const gpuCarryPlan = await readbackU32(device, carryPlanBuf, carryPlanBytes);

      const mismatches: string[] = [];
      if (gpuTotals[0] !== ref.totalPairs) mismatches.push(`totals[0]: gpu=${gpuTotals[0]} ref=${ref.totalPairs}`);
      if (gpuTotals[1] !== ref.totalCarries) mismatches.push(`totals[1]: gpu=${gpuTotals[1]} ref=${ref.totalCarries}`);
      if (gpuTotals[2] !== ref.totalNew) mismatches.push(`totals[2]: gpu=${gpuTotals[2]} ref=${ref.totalNew}`);
      // Per-bucket outputs; stops scanning once enough mismatches are
      // collected. Only indices 0..BUCKETS-1 of new_offsets are compared.
      for (let b = 0; b < BUCKETS && mismatches.length < 8; b++) {
        if (gpuNewCounts[b] !== ref.newCounts[b]) mismatches.push(`newCounts[${b}]: gpu=${gpuNewCounts[b]} ref=${ref.newCounts[b]}`);
        if (gpuNewOffsets[b] !== ref.newOffsets[b]) mismatches.push(`newOffsets[${b}]: gpu=${gpuNewOffsets[b]} ref=${ref.newOffsets[b]}`);
      }
      // chunk_plan/scatter_plan: compare element-wise.
      let cpFails = 0;
      for (let i = 0; i < ref.chunkPlan.length; i++) {
        if (gpuChunkPlan[i] !== ref.chunkPlan[i]) { cpFails++; if (cpFails <= 3) mismatches.push(`chunkPlan[${i}]: gpu=${gpuChunkPlan[i]} ref=${ref.chunkPlan[i]}`); }
      }
      let spFails = 0;
      for (let i = 0; i < ref.scatterPlan.length; i++) {
        if (gpuScatterPlan[i] !== ref.scatterPlan[i]) { spFails++; if (spFails <= 3) mismatches.push(`scatterPlan[${i}]: gpu=${gpuScatterPlan[i]} ref=${ref.scatterPlan[i]}`); }
      }
      // carry_plan: only the first 2 * totalCarries entries are meaningful.
      let cyFails = 0;
      for (let i = 0; i < 2 * ref.totalCarries; i++) {
        if (gpuCarryPlan[i] !== ref.carryPlan[i]) { cyFails++; if (cyFails <= 3) mismatches.push(`carryPlan[${i}]: gpu=${gpuCarryPlan[i]} ref=${ref.carryPlan[i]}`); }
      }
      if (mismatches.length === 0 && cpFails === 0 && spFails === 0 && cyFails === 0) {
        validated = true;
        log('ok', 'validation: PASS — GPU planner output byte-equivalent to host reference');
      } else {
        log('err', `validation: FAIL — ${cpFails} chunkPlan, ${spFails} scatterPlan, ${cyFails} carryPlan mismatches; first few:`);
        for (const m of mismatches.slice(0, 10)) log('err', ` ${m}`);
      }
    }

    // Timed runs: DISP back-to-back planner dispatches in one command encoder.
    const samples: number[] = [];
    for (let r = 0; r < REPS; r++) {
      const enc = device.createCommandEncoder();
      for (let d = 0; d < DISP; d++) {
        const pass = enc.beginComputePass();
        pass.setPipeline(pipeline);
        pass.setBindGroup(0, bind);
        pass.dispatchWorkgroups(1, 1, 1);
        pass.end();
      }
      // Encoding is excluded from the sample; only submit + completion is timed.
      const t0 = performance.now();
      device.queue.submit([enc.finish()]);
      await device.queue.onSubmittedWorkDone();
      samples.push(performance.now() - t0);
    }
    const med = median(samples);
    const mn = Math.min(...samples);
    const mx = Math.max(...samples);
    // Wall samples are in ms; per-dispatch figures are reported in μs.
    const perDispatchMin = (mn / DISP) * 1000;
    const perDispatchMed = (med / DISP) * 1000;
    const perDispatchMax = (mx / DISP) * 1000;

    log(
      'ok',
      `per-planner: min=${perDispatchMin.toFixed(2)}μs median=${perDispatchMed.toFixed(2)}μs max=${perDispatchMax.toFixed(2)}μs` +
        ` (total wall: min=${mn.toFixed(2)}ms median=${med.toFixed(2)}ms max=${mx.toFixed(2)}ms over DISP=${DISP})`,
    );

    return {
      buckets: BUCKETS, lambda: LAMBDA, s: S, tpb: TPB, per_thread: PER_THREAD, disp: DISP, reps: REPS,
      total_pairs: ref.totalPairs, total_carries: ref.totalCarries, num_chunks: ref.numChunks,
      per_dispatch_us: { min: perDispatchMin, median: perDispatchMed, max: perDispatchMax },
      wall_samples_ms: samples,
      validated,
    };
  } finally {
    for (const buf of owned) buf.destroy();
  }
}

// Reads query parameter `name` as a base-10 integer. A missing or empty
// parameter yields `fallback`; anything that does not parse to an integer
// >= 1 throws, so NaN / zero / negative values never reach the bench.
function readPositiveInt(qp: URLSearchParams, name: string, fallback: number): number {
  const raw = qp.get(name);
  if (!raw) return fallback;
  const v = parseInt(raw, 10);
  if (!Number.isInteger(v) || v <= 0) throw new Error(`?${name} must be a positive integer (got "${raw}")`);
  return v;
}

// Applies URL query overrides (?buckets, ?lambda, ?s, ?tpb, ?per, ?disp,
// ?reps, ?validate=1) to the module-level settings and returns a snapshot of
// the effective values. Throws on a non-positive or non-numeric value.
function parseParams() {
  const qp = new URLSearchParams(window.location.search);
  BUCKETS = readPositiveInt(qp, 'buckets', BUCKETS);
  LAMBDA = readPositiveInt(qp, 'lambda', LAMBDA);
  S = readPositiveInt(qp, 's', S);
  TPB = readPositiveInt(qp, 'tpb', TPB);
  PER_THREAD = readPositiveInt(qp, 'per', PER_THREAD);
  DISP = readPositiveInt(qp, 'disp', DISP);
  REPS = readPositiveInt(qp, 'reps', REPS);
  if (qp.get('validate') === '1') VALIDATE = true;
  return { buckets: BUCKETS, lambda: LAMBDA, s: S, tpb: TPB, per: PER_THREAD, disp: DISP, reps: REPS, validate: VALIDATE };
}

// Page entry point: parses parameters, acquires the WebGPU device, runs the
// bench once and records the result. Any failure (including a rejected
// parameter) is logged and stored in benchState instead of being rethrown.
async function main() {
  try {
    if (!('gpu' in navigator)) throw new Error('navigator.gpu missing');
    const params = parseParams();
    benchState.params = params;
    log('info', `params: ${JSON.stringify(params)}`);
    benchState.state = 'running';
    const device = await get_device();
    log('info', 'WebGPU device acquired');
    const sm = new ShaderManager(4, BUCKETS, BN254_CURVE_CONFIG, false);
    const r = await runOne(device, sm);
    benchState.results.push(r);
    resultsClient.postProgress({ kind: 'planner_done', per_dispatch_us: r.per_dispatch_us, validated: r.validated });
    benchState.state = 'done';
    log('ok', 'done');
  } catch (e) {
    const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e);
    log('err', `FATAL: ${msg}`);
    benchState.state = 'error';
    benchState.error = msg;
  }
}

// NOTE: the source's line wrapping splits the statement below; it continues
// on the next source line, so it is intentionally left open here.
main()
  .catch(e => { const msg = e instanceof Error ?
e.message : String(e); log('err', `unhandled: ${msg}`); benchState.state = 'error'; benchState.error = msg; }) + .finally(() => { postFinal().catch(() => {}); }); diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs index a9e3eacd368d..4bb7c19a9751 100644 --- a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs +++ b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs @@ -127,6 +127,17 @@ if (!TARGETS[argv.target]) { const pageMap = { "bench-batch-affine": "/dev/msm-webgpu/bench-batch-affine.html", + "bench-fused-wg-scan": "/dev/msm-webgpu/bench-fused-wg-scan.html", + "bench-ba-rev-packed-carry": "/dev/msm-webgpu/bench-ba-rev-packed-carry.html", + "bench-msm-chain": "/dev/msm-webgpu/bench-msm-chain.html", + "bench-ba-pair-disjoint": "/dev/msm-webgpu/bench-ba-pair-disjoint.html", + "bench-msm-tree": "/dev/msm-webgpu/bench-msm-tree.html", + "bench-msm-tree-v2": "/dev/msm-webgpu/bench-msm-tree-v2.html", + "bench-msm-tree-v3": "/dev/msm-webgpu/bench-msm-tree-v3.html", + "bench-planner": "/dev/msm-webgpu/bench-planner.html", + "bench-csr-to-v2": "/dev/msm-webgpu/bench-csr-to-v2.html", + "bench-msm-oracle": "/dev/msm-webgpu/bench-msm-oracle.html", + "bench-msm-oracle-prod": "/dev/msm-webgpu/bench-msm-oracle-prod.html", "bench-smvp-tree": "/dev/msm-webgpu/bench-smvp-tree.html", sanity: "/dev/msm-webgpu/index.html", }; diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts index 6c97ef584582..24bd0a2f4ba1 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts @@ -5,11 +5,29 @@ import { batch_affine_apply as batch_affine_apply_shader, batch_affine_apply_scatter as batch_affine_apply_scatter_shader, batch_affine_dispatch_args as batch_affine_dispatch_args_shader, + ba_carry_copy_bench as ba_carry_copy_bench_shader, + ba_fused_super_bench as 
ba_fused_super_bench_shader, + ba_marshal_chain_bench as ba_marshal_chain_bench_shader, + ba_marshal_pairs_bench as ba_marshal_pairs_bench_shader, + ba_marshal_tree_l0_bench as ba_marshal_tree_l0_bench_shader, + ba_pair_disjoint_bench as ba_pair_disjoint_bench_shader, + ba_pair_disjoint_tree_bench as ba_pair_disjoint_tree_bench_shader, + ba_planner_bench as ba_planner_bench_shader, + ba_planner_v2_bench as ba_planner_v2_bench_shader, + ba_planner_v2_prod as ba_planner_v2_prod_shader, + ba_marshal_pairs_prod as ba_marshal_pairs_prod_shader, + ba_pair_disjoint_tree_prod as ba_pair_disjoint_tree_prod_shader, + ba_scatter_pairs_prod as ba_scatter_pairs_prod_shader, + ba_carry_copy_prod as ba_carry_copy_prod_shader, + ba_scatter_pairs_bench as ba_scatter_pairs_bench_shader, + ba_tail_reduce_bench as ba_tail_reduce_bench_shader, + ba_rev_packed_carry_bench as ba_rev_packed_carry_bench_shader, bench_batch_affine as bench_batch_affine_shader, batch_affine_finalize as batch_affine_finalize_shader, batch_affine_finalize_apply as batch_affine_finalize_apply_shader, batch_affine_finalize_collect as batch_affine_finalize_collect_shader, batch_affine_fused_revcarry as batch_affine_fused_revcarry_shader, + batch_affine_fused_wg_scan as batch_affine_fused_wg_scan_shader, batch_affine_init as batch_affine_init_shader, batch_affine_schedule as batch_affine_schedule_shader, batch_inverse as batch_inverse_shader, @@ -27,6 +45,9 @@ import { bpr_bn254 as bpr_bn254_shader, convert_point_coords_and_decompose_scalars, convert_points_only as convert_points_only_shader, + csr_to_v2_active_sums as csr_to_v2_active_sums_shader, + csr_to_v2_meta as csr_to_v2_meta_shader, + v2_to_running as v2_to_running_shader, decompose_scalars_signed_only as decompose_scalars_signed_only_shader, decompress_g1_bn254 as decompress_g1_bn254_shader, divsteps_bench as divsteps_bench_shader, @@ -42,6 +63,7 @@ import { mont_pro_product_f32_22_sos3uv3 as montgomery_product_f32_22_sos3uv3_funcs, 
mont_pro_product_karat_yuval as montgomery_product_karat_yuval_funcs, mulhilo_22 as mulhilo_22_funcs, + packed_field as packed_field_funcs, smvp_bn254 as smvp_bn254_shader, smvp_tree_entry_bucket_id as smvp_tree_entry_bucket_id_shader, smvp_tree_phase1 as smvp_tree_phase1_shader, @@ -762,6 +784,491 @@ ${packLines.join('\n')} ); } + /** + * Marshal kernel for the bench-msm-chain pipeline: transposes a + * CSR-sorted point list into the strided SoA layout the + * ba_rev_packed_carry chain kernel consumes. Pure memory shuffle. + */ + public gen_ba_marshal_chain_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_marshal_chain_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + return mustache.render( + ba_marshal_chain_bench_shader, + { workgroup_size, s, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + + /** + * Fused super-kernel for the v3 pipeline: combines marshal + disjoint + * + scatter into one kernel. Per chunk-thread: reads chunk_plan + + * scatter_plan, gathers from active_sums_old, computes batched- + * inverse pair sums in registers, writes directly to active_sums_new. + * Carry is a separate kernel. 
+ */ + public gen_ba_fused_super_bench_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_fused_super_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + const dec = this.decoupledPackUnpackWgsl(); + return mustache.render( + ba_fused_super_bench_shader, + { + workgroup_size, s, + word_size: this.word_size, num_words: this.num_words, n0: this.n0, + p_limbs: this.p_limbs, r_limbs: this.r_limbs, r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, mask: this.mask, + two_pow_word_size: this.two_pow_word_size, p_inv_mod_2w: this.p_inv_mod_2w, + p_inv_by_a_lo: this.p_inv_by_a_lo, + dec_unpack: dec.unpack, dec_pack: dec.pack, recompile: this.recompile, + }, + { + structs, bigint_funcs, + montgomery_product_funcs: this.mont_product_src, + field_funcs, fr_pow_funcs, bigint_by_funcs, by_inverse_a_funcs, + }, + ); + } + + /** + * v2 GPU planner: single-kernel scan + scatter. One workgroup of TPB + * threads handles all B buckets via per-thread local scan + workgroup- + * wide Hillis-Steele scan + per-thread scatter. No atomics, no host + * sync, single dispatch. Scales to B <= TPB * PER_THREAD within one + * workgroup (e.g. 256 * 32 = 8192 buckets). 
+ */ + public gen_ba_planner_v2_bench_shader(workgroup_size: number, per_thread: number, s: number, pair_cap: number = 64): string { + if (workgroup_size <= 0 || per_thread <= 0 || s <= 0 || pair_cap <= 0 || + !Number.isInteger(workgroup_size) || !Number.isInteger(per_thread) || !Number.isInteger(s) || !Number.isInteger(pair_cap)) { + throw new Error(`gen_ba_planner_v2_bench_shader: positive integer args required`); + } + return mustache.render( + ba_planner_v2_bench_shader, + { workgroup_size, per_thread, pair_cap, s, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + + /** + * GPU-side bin-packing planner for the v3 pipeline. One thread per + * bucket; atomicAdd reserves global per-pair / per-carry / per-new- + * slot offsets; the thread writes chunk_plan + scatter_plan + + * carry_plan + new_counts + new_offsets for its bucket. Pair-count + * loop bounded by compile-time `pair_cap` (defaults to 64 — covers + * Poisson(λ=32) max-count without issue). + */ + public gen_ba_planner_bench_shader(workgroup_size: number, s: number, pair_cap: number = 64): string { + if (workgroup_size <= 0 || s <= 0 || pair_cap <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s) || !Number.isInteger(pair_cap)) { + throw new Error(`gen_ba_planner_bench_shader: workgroup_size (${workgroup_size}), s (${s}), and pair_cap (${pair_cap}) must be positive integers`); + } + return mustache.render( + ba_planner_bench_shader, + { workgroup_size, s, pair_cap, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + + /** + * Bin-packed pair-tree: marshal kernel that gathers operands from a + * generic active_sums buffer per chunk_plan. Works at any level + * (L0 active_sums = bucket-sorted point pool, L1+ = previous + * level's pair-sums + carries). 
+ */ + public gen_ba_marshal_pairs_bench_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_marshal_pairs_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + return mustache.render( + ba_marshal_pairs_bench_shader, + { workgroup_size, s, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + + /** + * Bin-packed pair-tree: scatter kernel that places the disjoint + * kernel's strided outputs at per-bucket destinations in + * active_sums_new per scatter_plan. + */ + public gen_ba_scatter_pairs_bench_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_scatter_pairs_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + return mustache.render( + ba_scatter_pairs_bench_shader, + { workgroup_size, s, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + + /** + * Layout converter (active_sums materialization): copies packed + * 8×u32 base coords from cached_bases into bucket-major active_sums + * indexed by val_idx. One thread per (subtask, slot). Pure raw vec4 + * copy — no field-element math. + */ + public gen_csr_to_v2_active_sums_shader(workgroup_size: number): string { + if (workgroup_size <= 0 || !Number.isInteger(workgroup_size)) { + throw new Error(`gen_csr_to_v2_active_sums_shader: workgroup_size (${workgroup_size}) must be a positive integer`); + } + return mustache.render( + csr_to_v2_active_sums_shader, + { workgroup_size, recompile: this.recompile }, + {}, + ); + } + + /** + * Production v2 GPU planner. 
Same algorithm as + * gen_ba_planner_v2_bench_shader, additionally emits per-level + * dispatch_args triples into totals[4..6] (marshal/disjoint/scatter) + * and totals[7..9] (carry) so the host orchestrator can drive the + * four downstream prod kernels via dispatchWorkgroupsIndirect with + * zero pad-chunk waste. wgi must match the workgroup size used to + * compile the four downstream prod kernels. + */ + public gen_ba_planner_v2_prod_shader(workgroup_size: number, per_thread: number, s: number, wgi: number, pair_cap: number = 64): string { + if (workgroup_size <= 0 || per_thread <= 0 || s <= 0 || wgi <= 0 || pair_cap <= 0 || + !Number.isInteger(workgroup_size) || !Number.isInteger(per_thread) || !Number.isInteger(s) || !Number.isInteger(wgi) || !Number.isInteger(pair_cap)) { + throw new Error(`gen_ba_planner_v2_prod_shader: positive integer args required`); + } + return mustache.render( + ba_planner_v2_prod_shader, + { workgroup_size, per_thread, pair_cap, s, wgi, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + + /** + * Marshal pairs — prod variant. Reads num_chunks from + * totals[3] (storage), dispatched indirectly off totals[4..6]. + */ + public gen_ba_marshal_pairs_prod_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_marshal_pairs_prod_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + return mustache.render( + ba_marshal_pairs_prod_shader, + { workgroup_size, s, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + + /** + * Disjoint pair-sum tree — prod variant. Reads T_curr from + * totals[3] (storage), always uses the final-mode strided write. 
+ */ + public gen_ba_pair_disjoint_tree_prod_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_pair_disjoint_tree_prod_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + const dec = this.decoupledPackUnpackWgsl(); + return mustache.render( + ba_pair_disjoint_tree_prod_shader, + { + workgroup_size, + s, + word_size: this.word_size, + num_words: this.num_words, + n0: this.n0, + p_limbs: this.p_limbs, + r_limbs: this.r_limbs, + r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, + mask: this.mask, + two_pow_word_size: this.two_pow_word_size, + p_inv_mod_2w: this.p_inv_mod_2w, + p_inv_by_a_lo: this.p_inv_by_a_lo, + dec_unpack: dec.unpack, + dec_pack: dec.pack, + recompile: this.recompile, + }, + { + structs, + bigint_funcs, + montgomery_product_funcs: this.mont_product_src, + field_funcs, + fr_pow_funcs, + bigint_by_funcs, + by_inverse_a_funcs, + }, + ); + } + + /** + * Scatter pairs — prod variant. Reads T from totals[3] (storage). + */ + public gen_ba_scatter_pairs_prod_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_scatter_pairs_prod_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + return mustache.render( + ba_scatter_pairs_prod_shader, + { workgroup_size, s, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + + /** + * Carry copy — prod variant. Reads num_carries from totals[1] (storage). 
+ */ + public gen_ba_carry_copy_prod_shader(workgroup_size: number): string { + if (workgroup_size <= 0 || !Number.isInteger(workgroup_size)) { + throw new Error(`gen_ba_carry_copy_prod_shader: workgroup_size (${workgroup_size}) must be a positive integer`); + } + return mustache.render( + ba_carry_copy_prod_shader, + { workgroup_size, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + + /** + * Boundary adapter for the v2 pair-tree -> production finalize: copies + * the per-bucket reduced packed point out of the v2 combined-SoA + * active_sums buffer into the production running_x / running_y layout + * and writes bucket_active. One thread per (subtask, bucket_local); + * the caller binds running_x / running_y / bucket_active sub-views + * offset by subtask_idx * num_columns. + */ + public gen_v2_to_running_shader(workgroup_size: number): string { + if (workgroup_size <= 0 || !Number.isInteger(workgroup_size)) { + throw new Error(`gen_v2_to_running_shader: workgroup_size (${workgroup_size}) must be a positive integer`); + } + return mustache.render( + v2_to_running_shader, + { workgroup_size, recompile: this.recompile }, + {}, + ); + } + + /** + * Layout converter (meta derivation): writes per-bucket count and + * subtask-relative offset from cuZK row_ptr. One thread per + * (subtask, bucket). + */ + public gen_csr_to_v2_meta_shader(workgroup_size: number): string { + if (workgroup_size <= 0 || !Number.isInteger(workgroup_size)) { + throw new Error(`gen_csr_to_v2_meta_shader: workgroup_size (${workgroup_size}) must be a positive integer`); + } + return mustache.render( + csr_to_v2_meta_shader, + { workgroup_size, recompile: this.recompile }, + {}, + ); + } + + /** + * Bin-packed pair-tree: carry-copy kernel. Propagates the odd-count + * carry element forward to the next level without modification. + * Pure memory shuffle. 
+ */ + public gen_ba_carry_copy_bench_shader(workgroup_size: number): string { + if (workgroup_size <= 0 || !Number.isInteger(workgroup_size)) { + throw new Error(`gen_ba_carry_copy_bench_shader: workgroup_size (${workgroup_size}) must be a positive integer`); + } + return mustache.render( + ba_carry_copy_bench_shader, + { workgroup_size, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + + /** + * Tree variant of the disjoint pair-sum kernel: writes outputs in the + * layout the next pair-tree level expects as input, so multi-level + * reductions can chain without an intervening marshal pass. Per + * thread: 2*S inputs -> S disjoint pair sums. Final-level flag (via + * params.z) switches to a simple strided write for the last pass. + */ + public gen_ba_pair_disjoint_tree_bench_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_pair_disjoint_tree_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + const dec = this.decoupledPackUnpackWgsl(); + return mustache.render( + ba_pair_disjoint_tree_bench_shader, + { + workgroup_size, + s, + word_size: this.word_size, + num_words: this.num_words, + n0: this.n0, + p_limbs: this.p_limbs, + r_limbs: this.r_limbs, + r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, + mask: this.mask, + two_pow_word_size: this.two_pow_word_size, + p_inv_mod_2w: this.p_inv_mod_2w, + p_inv_by_a_lo: this.p_inv_by_a_lo, + dec_unpack: dec.unpack, + dec_pack: dec.pack, + recompile: this.recompile, + }, + { + structs, + bigint_funcs, + montgomery_product_funcs: this.mont_product_src, + field_funcs, + fr_pow_funcs, + bigint_by_funcs, + by_inverse_a_funcs, + }, + ); + } + + /** + * Level-0 marshal kernel for the bench-msm-tree pipeline: transposes + * CSR-sorted point indices into the 2-plane strided SoA layout the + * 
tree-disjoint kernel reads. Pure memory shuffle. + */ + public gen_ba_marshal_tree_l0_bench_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_marshal_tree_l0_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + return mustache.render( + ba_marshal_tree_l0_bench_shader, + { workgroup_size, s, num_words: this.num_words, recompile: this.recompile }, + { structs }, + ); + } + + /** + * Tail kernel for the bench-msm-tree pipeline: one thread per small + * bucket (count < 2*S), serial per-thread chain with one fr_inv_by_a + * per add (no batched inversion). Bounded loop up to compile-time + * TAIL_CAP = 2*S - 1. + */ + public gen_ba_tail_reduce_bench_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_tail_reduce_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + const tail_cap = 2 * s - 1; + const dec = this.decoupledPackUnpackWgsl(); + return mustache.render( + ba_tail_reduce_bench_shader, + { + workgroup_size, + tail_cap, + word_size: this.word_size, + num_words: this.num_words, + n0: this.n0, + p_limbs: this.p_limbs, + r_limbs: this.r_limbs, + r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, + mask: this.mask, + two_pow_word_size: this.two_pow_word_size, + p_inv_mod_2w: this.p_inv_mod_2w, + p_inv_by_a_lo: this.p_inv_by_a_lo, + dec_unpack: dec.unpack, + dec_pack: dec.pack, + recompile: this.recompile, + }, + { + structs, + bigint_funcs, + montgomery_product_funcs: this.mont_product_src, + field_funcs, + fr_pow_funcs, + bigint_by_funcs, + by_inverse_a_funcs, + }, + ); + } + + /** + * Standalone microbench for the disjoint pair-sum kernel: each + * thread reduces 2*S input points to S disjoint pair sums 
R_k = + * P_{2k} + P_{2k+1}, using one batched fr_inv_by_a per chunk of S. + * Reclaims the 50% kernel-efficiency loss inherent in + * ba_rev_packed_carry (which produces S overlapping pair sums of + * which only S/2 are usable for pair-tree reduction). + */ + public gen_ba_pair_disjoint_bench_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_pair_disjoint_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + const dec = this.decoupledPackUnpackWgsl(); + return mustache.render( + ba_pair_disjoint_bench_shader, + { + workgroup_size, + s, + word_size: this.word_size, + num_words: this.num_words, + n0: this.n0, + p_limbs: this.p_limbs, + r_limbs: this.r_limbs, + r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, + mask: this.mask, + two_pow_word_size: this.two_pow_word_size, + p_inv_mod_2w: this.p_inv_mod_2w, + p_inv_by_a_lo: this.p_inv_by_a_lo, + dec_unpack: dec.unpack, + dec_pack: dec.pack, + recompile: this.recompile, + }, + { + structs, + bigint_funcs, + montgomery_product_funcs: this.mont_product_src, + field_funcs, + fr_pow_funcs, + bigint_by_funcs, + by_inverse_a_funcs, + }, + ); + } + + /** + * Standalone microbench for the recovered ba_rev_packed_carry kernel: + * SoA-packed 8x u32 storage across 4 input planes (A.x, A.y, P.x, P.y), + * strided per-thread access (e = t + i*T), forward prefix-product + + * single fr_inv_by_a + backward peel with resident-accumulator + * load-carry (A_{i+1} := P_i). The canonical kernel that hit ~22 + * ns/pair on M2 / Chrome 148. 
+ */ + public gen_ba_rev_packed_carry_bench_shader(workgroup_size: number, s: number): string { + if (workgroup_size <= 0 || s <= 0 || !Number.isInteger(workgroup_size) || !Number.isInteger(s)) { + throw new Error(`gen_ba_rev_packed_carry_bench_shader: workgroup_size (${workgroup_size}) and s (${s}) must be positive integers`); + } + const dec = this.decoupledPackUnpackWgsl(); + return mustache.render( + ba_rev_packed_carry_bench_shader, + { + workgroup_size, + s, + word_size: this.word_size, + num_words: this.num_words, + n0: this.n0, + p_limbs: this.p_limbs, + r_limbs: this.r_limbs, + r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, + mask: this.mask, + two_pow_word_size: this.two_pow_word_size, + p_inv_mod_2w: this.p_inv_mod_2w, + p_inv_by_a_lo: this.p_inv_by_a_lo, + dec_unpack: dec.unpack, + dec_pack: dec.pack, + recompile: this.recompile, + }, + { + structs, + bigint_funcs, + montgomery_product_funcs: this.mont_product_src, + field_funcs, + fr_pow_funcs, + bigint_by_funcs, + by_inverse_a_funcs, + }, + ); + } + public gen_batch_affine_init_shader(workgroup_size: number, packed = false): string { const dec = this.decoupledPackUnpackWgsl(); return mustache.render( @@ -877,6 +1384,57 @@ ${packLines.join('\n')} ); } + /** + * Workgroup-scan fused batch-affine round kernel (v2). Mirrors the + * `bench_batch_affine` design (TPB threads cooperating over a + * BATCH_SIZE=TPB*BS pair slice with one fr_inv_by_a per workgroup) + * but with bucket-indirect loads via pair_target_meta. Every + * field-element variable is `PackedField`; no per-load unpack at + * the kernel level. 
+ */ + public gen_batch_affine_fused_wg_scan_shader(tpb: number, bs: number): string { + if (tpb <= 0 || bs <= 0 || !Number.isInteger(tpb) || !Number.isInteger(bs)) { + throw new Error(`gen_batch_affine_fused_wg_scan_shader: tpb (${tpb}) and bs (${bs}) must be positive integers`); + } + if ((tpb & (tpb - 1)) !== 0) { + throw new Error(`gen_batch_affine_fused_wg_scan_shader: tpb (${tpb}) must be a power of two (Hillis-Steele scan)`); + } + const batch_size = tpb * bs; + const dec = this.decoupledPackUnpackWgsl(); + return mustache.render( + batch_affine_fused_wg_scan_shader, + { + tpb, + bs, + batch_size, + word_size: this.word_size, + num_words: this.num_words, + n0: this.n0, + p_limbs: this.p_limbs, + r_limbs: this.r_limbs, + r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, + mask: this.mask, + two_pow_word_size: this.two_pow_word_size, + p_inv_mod_2w: this.p_inv_mod_2w, + p_inv_by_a_lo: this.p_inv_by_a_lo, + dec_unpack: dec.unpack, + dec_pack: dec.pack, + recompile: this.recompile, + }, + { + structs, + bigint_funcs, + montgomery_product_funcs: this.mont_product_src, + field_funcs, + fr_pow_funcs, + bigint_by_funcs, + by_inverse_a_funcs, + packed_field_funcs, + }, + ); + } + public gen_batch_affine_finalize_shader(workgroup_size: number, num_csr_cols: number): string { return mustache.render( batch_affine_finalize_shader, diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts b/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts new file mode 100644 index 000000000000..51ae0e363a34 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/cuzk/smvp_v2_pair_tree.ts @@ -0,0 +1,601 @@ +/// + +/** + * v2 bin-packed pair-tree MSM bucket-accumulate orchestrator. + * + * Drop-in replacement for the cuZK round-loop (`smvp_batch_affine_gpu`'s + * schedule + batch_inverse_parallel + apply_scatter per round) for the + * Pippenger bucket-accumulate phase. 
For each window: + * + * csr_to_v2_meta row_ptr -> per-bucket count + offset + * csr_to_v2_active_sums val_idx + cached bases -> bucket-major + * active_sums (combined SoA, packed 8x u32) + * for level in 0..max_levels: + * ba_planner_v2_prod counts/offsets -> chunk_plan / scatter_plan + * / carry_plan + new_counts/new_offsets + + * totals (incl. per-level dispatch_args + * triples at totals[4..6] and totals[7..9]) + * ba_marshal_pairs_prod indirect dispatch off totals[4..6] + * ba_pair_disjoint_tree_prod indirect dispatch off totals[4..6] + * ba_scatter_pairs_prod indirect dispatch off totals[4..6] + * ba_carry_copy_prod indirect dispatch off totals[7..9] + * v2_to_running final active_sums slot per non-empty + * bucket -> running_x / running_y / + * bucket_active in production layout + * + * The prod kernels read num_chunks (= ceil(total_pairs / S)) and + * num_carries from the planner's totals storage output; the host + * dispatches them via dispatchWorkgroupsIndirect. Each level's + * downstream dispatch is sized to EXACTLY the chunks/carries the + * planner produced — zero pad-chunk waste, the runtime advantage + * over the pad-fill alternative. + * + * All dispatches for all windows are recorded onto one command encoder + * and submitted once. Submit overhead is paid once per MSM, not once + * per window or once per level. 
+ * + * Layout boundaries: + * active_sums (combined SoA, one buffer per ping-pong copy): + * plane 0 (x) at vec4 indices [0, PG * M) + * plane 1 (y) at vec4 indices [PG * M, 2 * PG * M) + * element layout: PG=2 vec4 at [PG*elem, PG*elem+1] + * M = input_size + 3 (the last 3 slots are reserved: pad_left at + * input_size, pad_right at input_size + 1, discard at input_size + 2. + * The pad pair is what the planner emits into the chunk-tail for + * filler pairs — we initialise it once at orchestrator start with + * distinct-x Montgomery-form values so the disjoint kernel's lean + * affine add is well-defined on pad chunks even though their sums + * get scattered to the discard slot) + * running_x / running_y (production, separate buffers): + * packed 8x u32 = 2 vec4 per (subtask, bucket_local), at + * [PG * bucket_global, PG * bucket_global + 1] with + * bucket_global = subtask_idx * num_columns + bucket_local + * bucket_active: u32 per bucket_global + * v2_to_running binds running_x / running_y / bucket_active at an + * element offset of subtask_idx * num_columns (converted to bytes per + * buffer) so a single per-window dispatch lands the result at the + * right slab. 
+ */ + +import { ShaderManager } from './shader_manager.js'; + +const PG = 2; +const PG_VEC4_BYTES = 16; +const ELEMENT_BYTES = PG * PG_VEC4_BYTES; + +export interface SmvpV2PairTreeOptions { + device: GPUDevice; + shaderManager: ShaderManager; + num_subtasks: number; + num_columns: number; + input_size: number; + + s?: number; + tpb?: number; + per_thread?: number; + wgi?: number; + max_levels?: number; + + val_idx_buf: GPUBuffer; + row_ptr_buf: GPUBuffer; + point_x_buf: GPUBuffer; + point_y_buf: GPUBuffer; + + running_x_buf: GPUBuffer; + running_y_buf: GPUBuffer; + bucket_active_buf: GPUBuffer; +} + +export interface SmvpV2PairTreeStats { + levels_per_window: number; + num_subtasks: number; + num_columns: number; + total_passes: number; + gpu_wall_ms: number; +} + +function roStorageEntry(binding: number): GPUBindGroupLayoutEntry { + return { binding, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }; +} +function rwStorageEntry(binding: number): GPUBindGroupLayoutEntry { + return { binding, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }; +} +function uniformEntry(binding: number): GPUBindGroupLayoutEntry { + return { binding, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }; +} + +async function compilePipeline( + device: GPUDevice, + layout: GPUBindGroupLayout, + code: string, + key: string, +): Promise { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + const errs: string[] = []; + for (const m of info.messages) { + const line = `[smvp-v2 ${key}] ${m.type}: ${m.message} (line ${m.lineNum}, col ${m.linePos})`; + if (m.type === 'error') { + console.error(line); + errs.push(line); + } else { + console.warn(line); + } + } + if (errs.length > 0) { + throw new Error(`WGSL compile failed for ${key}: ${errs.slice(0, 4).join(' | ')}`); + } + return device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), 
+ compute: { module, entryPoint: 'main' }, + }); +} + +interface Pipelines { + csrMeta: GPUComputePipeline; + csrActive: GPUComputePipeline; + planner: GPUComputePipeline; + marshal: GPUComputePipeline; + disjoint: GPUComputePipeline; + scatter: GPUComputePipeline; + carry: GPUComputePipeline; + v2ToRunning: GPUComputePipeline; + layouts: { + meta: GPUBindGroupLayout; + active: GPUBindGroupLayout; + planner: GPUBindGroupLayout; + marshal: GPUBindGroupLayout; + disjoint: GPUBindGroupLayout; + scatter: GPUBindGroupLayout; + carry: GPUBindGroupLayout; + v2Run: GPUBindGroupLayout; + }; +} + +async function compileAll( + device: GPUDevice, + sm: ShaderManager, + wgi: number, + s: number, + tpb: number, + per_thread: number, +): Promise { + const layouts: Pipelines['layouts'] = { + meta: device.createBindGroupLayout({ + entries: [roStorageEntry(0), rwStorageEntry(1), rwStorageEntry(2), uniformEntry(3)], + }), + active: device.createBindGroupLayout({ + entries: [ + roStorageEntry(0), + roStorageEntry(1), + roStorageEntry(2), + rwStorageEntry(3), + uniformEntry(4), + ], + }), + planner: device.createBindGroupLayout({ + entries: [ + roStorageEntry(0), + roStorageEntry(1), + rwStorageEntry(2), + rwStorageEntry(3), + rwStorageEntry(4), + rwStorageEntry(5), + rwStorageEntry(6), + rwStorageEntry(7), + uniformEntry(8), + ], + }), + marshal: device.createBindGroupLayout({ + entries: [roStorageEntry(0), roStorageEntry(1), rwStorageEntry(2), roStorageEntry(3), uniformEntry(4)], + }), + disjoint: device.createBindGroupLayout({ + entries: [roStorageEntry(0), roStorageEntry(1), rwStorageEntry(2), roStorageEntry(3)], + }), + scatter: device.createBindGroupLayout({ + entries: [roStorageEntry(0), roStorageEntry(1), rwStorageEntry(2), roStorageEntry(3), uniformEntry(4)], + }), + carry: device.createBindGroupLayout({ + entries: [roStorageEntry(0), roStorageEntry(1), rwStorageEntry(2), roStorageEntry(3), uniformEntry(4)], + }), + v2Run: device.createBindGroupLayout({ + entries: [ + 
roStorageEntry(0), + roStorageEntry(1), + roStorageEntry(2), + rwStorageEntry(3), + rwStorageEntry(4), + rwStorageEntry(5), + uniformEntry(6), + ], + }), + }; + + const [csrMeta, csrActive, planner, marshal, disjoint, scatter, carry, v2ToRunning] = await Promise.all([ + compilePipeline(device, layouts.meta, sm.gen_csr_to_v2_meta_shader(wgi), `csr-meta-wg${wgi}`), + compilePipeline(device, layouts.active, sm.gen_csr_to_v2_active_sums_shader(wgi), `csr-active-wg${wgi}`), + compilePipeline( + device, + layouts.planner, + sm.gen_ba_planner_v2_prod_shader(tpb, per_thread, s, wgi, 64), + `planner-v2-prod-T${tpb}-P${per_thread}-S${s}-W${wgi}`, + ), + compilePipeline(device, layouts.marshal, sm.gen_ba_marshal_pairs_prod_shader(wgi, s), `marshal-prod-W${wgi}-S${s}`), + compilePipeline(device, layouts.disjoint, sm.gen_ba_pair_disjoint_tree_prod_shader(wgi, s), `disjoint-prod-W${wgi}-S${s}`), + compilePipeline(device, layouts.scatter, sm.gen_ba_scatter_pairs_prod_shader(wgi, s), `scatter-prod-W${wgi}-S${s}`), + compilePipeline(device, layouts.carry, sm.gen_ba_carry_copy_prod_shader(wgi), `carry-prod-W${wgi}`), + compilePipeline(device, layouts.v2Run, sm.gen_v2_to_running_shader(wgi), `v2-to-running-wg${wgi}`), + ]); + return { csrMeta, csrActive, planner, marshal, disjoint, scatter, carry, v2ToRunning, layouts }; +} + +interface Scratch { + activeA: GPUBuffer; + activeB: GPUBuffer; + chainBuf: GPUBuffer; + tempOut: GPUBuffer; + countsA: GPUBuffer; + countsB: GPUBuffer; + offsetsA: GPUBuffer; + offsetsB: GPUBuffer; + perLevelChunkPlan: GPUBuffer[]; + perLevelScatterPlan: GPUBuffer[]; + perLevelCarryPlan: GPUBuffer[]; + perLevelTotals: GPUBuffer[]; + metaParams: GPUBuffer; + activeParams: GPUBuffer; + plannerParams: GPUBuffer; + marshalConsts: GPUBuffer; + scatterConsts: GPUBuffer; + carryConsts: GPUBuffer; + v2RunParams: GPUBuffer; + M: number; + maxChunks: number; + perThread: number; +} + +function allocScratch( + device: GPUDevice, + num_columns: number, + input_size: 
number, + s: number, + max_levels: number, + tpb: number, + per_thread: number, +): Scratch { + // M = real slots (input_size) + 3 reserved tail slots: pad_left, + // pad_right, discard. Reserved slots aren't touched by the converter + // or by real bucket reductions (which only address [0, total_actives) < + // input_size), so the planner can safely pad-fill the last partial + // chunk of chunk_plan / scatter_plan with these constant indices. + const M = input_size + 3; + const maxChunks = Math.max(1, Math.ceil(input_size / 2 / s) + 1); + + const mk = (bytes: number, extra: GPUBufferUsageFlags = 0): GPUBuffer => + device.createBuffer({ size: bytes, usage: GPUBufferUsage.STORAGE | extra }); + + const activeBytes = 2 * PG * M * PG_VEC4_BYTES; + const activeA = mk(activeBytes, GPUBufferUsage.COPY_DST); + const activeB = mk(activeBytes, GPUBufferUsage.COPY_DST); + + const chainBuf = mk(2 * PG * (2 * s * maxChunks) * PG_VEC4_BYTES); + const tempOut = mk(2 * PG * (s * maxChunks) * PG_VEC4_BYTES); + + const countsBytes = num_columns * 4; + const offsetsBytes = num_columns * 4; + const countsA = mk(countsBytes); + const countsB = mk(countsBytes); + const offsetsA = mk(offsetsBytes); + const offsetsB = mk(offsetsBytes); + + const perLevelChunkPlan: GPUBuffer[] = []; + const perLevelScatterPlan: GPUBuffer[] = []; + const perLevelCarryPlan: GPUBuffer[] = []; + const perLevelTotals: GPUBuffer[] = []; + const chunkPlanBytes = 2 * s * maxChunks * 4; + const scatterPlanBytes = s * maxChunks * 4; + const carryPlanBytes = 2 * num_columns * 4; + const totalsBytes = Math.max(40, Math.ceil(40 / 16) * 16); + for (let lvl = 0; lvl < max_levels; lvl++) { + perLevelChunkPlan.push(mk(chunkPlanBytes)); + perLevelScatterPlan.push(mk(scatterPlanBytes)); + perLevelCarryPlan.push(mk(carryPlanBytes)); + perLevelTotals.push(mk(totalsBytes, GPUBufferUsage.INDIRECT | GPUBufferUsage.COPY_DST)); + } + + const ub = (bytes: number): GPUBuffer => + device.createBuffer({ size: Math.max(16, bytes), 
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + const metaParams = ub(16); + const activeParams = ub(16); + const plannerParams = ub(16); + const marshalConsts = ub(16); + const scatterConsts = ub(16); + const carryConsts = ub(16); + const v2RunParams = ub(16); + + return { + activeA, activeB, chainBuf, tempOut, + countsA, countsB, offsetsA, offsetsB, + perLevelChunkPlan, perLevelScatterPlan, perLevelCarryPlan, perLevelTotals, + metaParams, activeParams, plannerParams, + marshalConsts, scatterConsts, carryConsts, v2RunParams, + M, maxChunks, perThread: per_thread, + }; +} + +function destroyScratch(scratch: Scratch): void { + scratch.activeA.destroy(); + scratch.activeB.destroy(); + scratch.chainBuf.destroy(); + scratch.tempOut.destroy(); + scratch.countsA.destroy(); + scratch.countsB.destroy(); + scratch.offsetsA.destroy(); + scratch.offsetsB.destroy(); + for (const b of scratch.perLevelChunkPlan) b.destroy(); + for (const b of scratch.perLevelScatterPlan) b.destroy(); + for (const b of scratch.perLevelCarryPlan) b.destroy(); + for (const b of scratch.perLevelTotals) b.destroy(); + scratch.metaParams.destroy(); + scratch.activeParams.destroy(); + scratch.plannerParams.destroy(); + scratch.marshalConsts.destroy(); + scratch.scatterConsts.destroy(); + scratch.carryConsts.destroy(); + scratch.v2RunParams.destroy(); +} + +/** + * Run the v2 pair-tree MSM bucket-accumulate for ALL pippenger windows + * in a single GPU submit. + * + * On return the caller's running_x / running_y / bucket_active buffers + * hold each bucket's reduced packed point (or 0/inactive marker) ready + * for batch_affine_finalize_collect to consume. The caller's val_idx / + * row_ptr / cached-bases buffers are read-only. 
+ */ +export async function runSmvpV2PairTree(opts: SmvpV2PairTreeOptions): Promise { + const { + device, shaderManager, num_subtasks, num_columns, input_size, + val_idx_buf, row_ptr_buf, point_x_buf, point_y_buf, + running_x_buf, running_y_buf, bucket_active_buf, + } = opts; + const s = opts.s ?? 16; + const tpb = opts.tpb ?? 256; + const per_thread = opts.per_thread ?? Math.max(1, Math.ceil(num_columns / tpb)); + const wgi = opts.wgi ?? 64; + const max_levels = opts.max_levels ?? 8; + if (tpb * per_thread < num_columns) { + throw new Error(`smvp_v2_pair_tree: tpb*per_thread (${tpb}*${per_thread}=${tpb * per_thread}) must be >= num_columns (${num_columns}).`); + } + + const pipelines = await compileAll(device, shaderManager, wgi, s, tpb, per_thread); + const scratch = allocScratch(device, num_columns, input_size, s, max_levels, tpb, per_thread); + const M = scratch.M; + + const padLeft = input_size; + const padRight = input_size + 1; + const discard = input_size + 2; + device.queue.writeBuffer(scratch.metaParams, 0, new Uint32Array([num_columns, num_columns, 0, 0])); + device.queue.writeBuffer(scratch.activeParams, 0, new Uint32Array([input_size, M, 0, 0])); + device.queue.writeBuffer(scratch.plannerParams, 0, new Uint32Array([num_columns, padLeft, padRight, discard])); + device.queue.writeBuffer(scratch.marshalConsts, 0, new Uint32Array([M, 0, 0, 0])); + device.queue.writeBuffer(scratch.scatterConsts, 0, new Uint32Array([M, 0, 0, 0])); + device.queue.writeBuffer(scratch.carryConsts, 0, new Uint32Array([M, M, 0, 0])); + device.queue.writeBuffer(scratch.v2RunParams, 0, new Uint32Array([num_columns, M, 0, 0])); + + // Pad pair init. The planner's tail pad-fill writes chunk_plan entries + // (padLeft, padRight) and scatter_plan entries = discard for the + // unused slots of the last partial chunk. 
The disjoint kernel will + // then compute a (garbage) affine add of active_sums[padLeft] and + // active_sums[padRight] and scatter the result to + // active_sums_new[discard]. Pad slots must have distinct x so the + // affine-add formula doesn't divide by zero. Discard slot is a + // never-read tail slot reserved beyond the real bucket data. + const padPair = new Uint32Array(2 * PG * 4); + for (let i = 0; i < padPair.length; i++) padPair[i] = (0x9e3779b9 * (i + 1)) >>> 0; + if (padPair[0] === padPair[PG * 4]) padPair[PG * 4] ^= 1; + const planeXPadByteOff = PG * padLeft * PG_VEC4_BYTES; + const planeYPadByteOff = PG * M * PG_VEC4_BYTES + PG * padLeft * PG_VEC4_BYTES; + device.queue.writeBuffer(scratch.activeA, planeXPadByteOff, padPair as BufferSource); + device.queue.writeBuffer(scratch.activeA, planeYPadByteOff, padPair as BufferSource); + device.queue.writeBuffer(scratch.activeB, planeXPadByteOff, padPair as BufferSource); + device.queue.writeBuffer(scratch.activeB, planeYPadByteOff, padPair as BufferSource); + + const encoder = device.createCommandEncoder(); + let totalPasses = 0; + const directPass = (pipe: GPUComputePipeline, bind: GPUBindGroup, x: number, y = 1, z = 1): void => { + const pass = encoder.beginComputePass(); + pass.setPipeline(pipe); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroups(x, y, z); + pass.end(); + totalPasses++; + }; + const indirectPass = (pipe: GPUComputePipeline, bind: GPUBindGroup, argsBuf: GPUBuffer, byteOffset: number): void => { + const pass = encoder.beginComputePass(); + pass.setPipeline(pipe); + pass.setBindGroup(0, bind); + pass.dispatchWorkgroupsIndirect(argsBuf, byteOffset); + pass.end(); + totalPasses++; + }; + + const valIdxStride = input_size * 4; + const rowPtrStride = (num_columns + 1) * 4; + const runningStride = num_columns * ELEMENT_BYTES; + const bucketActiveStride = num_columns * 4; + + for (let st = 0; st < num_subtasks; st++) { + const valIdxView = { buffer: val_idx_buf, offset: st * valIdxStride, 
size: valIdxStride } as const; + const rowPtrView = { buffer: row_ptr_buf, offset: st * rowPtrStride, size: rowPtrStride } as const; + + const metaBind = device.createBindGroup({ + layout: pipelines.layouts.meta, + entries: [ + { binding: 0, resource: rowPtrView }, + { binding: 1, resource: { buffer: scratch.countsA } }, + { binding: 2, resource: { buffer: scratch.offsetsA } }, + { binding: 3, resource: { buffer: scratch.metaParams } }, + ], + }); + directPass(pipelines.csrMeta, metaBind, Math.ceil(num_columns / wgi)); + + const activeBind = device.createBindGroup({ + layout: pipelines.layouts.active, + entries: [ + { binding: 0, resource: valIdxView }, + { binding: 1, resource: { buffer: point_x_buf } }, + { binding: 2, resource: { buffer: point_y_buf } }, + { binding: 3, resource: { buffer: scratch.activeA } }, + { binding: 4, resource: { buffer: scratch.activeParams } }, + ], + }); + directPass(pipelines.csrActive, activeBind, Math.ceil(input_size / wgi)); + + let curActive: GPUBuffer = scratch.activeA; + let nextActive: GPUBuffer = scratch.activeB; + let curCounts: GPUBuffer = scratch.countsA; + let curOffsets: GPUBuffer = scratch.offsetsA; + let nextCounts: GPUBuffer = scratch.countsB; + let nextOffsets: GPUBuffer = scratch.offsetsB; + + for (let lvl = 0; lvl < max_levels; lvl++) { + const chunkPlanBuf = scratch.perLevelChunkPlan[lvl]; + const scatterPlanBuf = scratch.perLevelScatterPlan[lvl]; + const carryPlanBuf = scratch.perLevelCarryPlan[lvl]; + const totalsBuf = scratch.perLevelTotals[lvl]; + + const plannerBind = device.createBindGroup({ + layout: pipelines.layouts.planner, + entries: [ + { binding: 0, resource: { buffer: curCounts } }, + { binding: 1, resource: { buffer: curOffsets } }, + { binding: 2, resource: { buffer: chunkPlanBuf } }, + { binding: 3, resource: { buffer: scatterPlanBuf } }, + { binding: 4, resource: { buffer: carryPlanBuf } }, + { binding: 5, resource: { buffer: nextCounts } }, + { binding: 6, resource: { buffer: nextOffsets } }, + 
{ binding: 7, resource: { buffer: totalsBuf } }, + { binding: 8, resource: { buffer: scratch.plannerParams } }, + ], + }); + directPass(pipelines.planner, plannerBind, 1); + + const marshalBind = device.createBindGroup({ + layout: pipelines.layouts.marshal, + entries: [ + { binding: 0, resource: { buffer: chunkPlanBuf } }, + { binding: 1, resource: { buffer: curActive } }, + { binding: 2, resource: { buffer: scratch.chainBuf } }, + { binding: 3, resource: { buffer: totalsBuf } }, + { binding: 4, resource: { buffer: scratch.marshalConsts } }, + ], + }); + indirectPass(pipelines.marshal, marshalBind, totalsBuf, 16); + + const disjointBind = device.createBindGroup({ + layout: pipelines.layouts.disjoint, + entries: [ + { binding: 0, resource: { buffer: scratch.chainBuf } }, + { binding: 1, resource: { buffer: scratch.chainBuf } }, + { binding: 2, resource: { buffer: scratch.tempOut } }, + { binding: 3, resource: { buffer: totalsBuf } }, + ], + }); + indirectPass(pipelines.disjoint, disjointBind, totalsBuf, 16); + + const scatterBind = device.createBindGroup({ + layout: pipelines.layouts.scatter, + entries: [ + { binding: 0, resource: { buffer: scatterPlanBuf } }, + { binding: 1, resource: { buffer: scratch.tempOut } }, + { binding: 2, resource: { buffer: nextActive } }, + { binding: 3, resource: { buffer: totalsBuf } }, + { binding: 4, resource: { buffer: scratch.scatterConsts } }, + ], + }); + indirectPass(pipelines.scatter, scatterBind, totalsBuf, 16); + + const carryBind = device.createBindGroup({ + layout: pipelines.layouts.carry, + entries: [ + { binding: 0, resource: { buffer: carryPlanBuf } }, + { binding: 1, resource: { buffer: curActive } }, + { binding: 2, resource: { buffer: nextActive } }, + { binding: 3, resource: { buffer: totalsBuf } }, + { binding: 4, resource: { buffer: scratch.carryConsts } }, + ], + }); + indirectPass(pipelines.carry, carryBind, totalsBuf, 28); + + [curActive, nextActive] = [nextActive, curActive]; + [curCounts, nextCounts] = 
[nextCounts, curCounts]; + [curOffsets, nextOffsets] = [nextOffsets, curOffsets]; + } + + const subtaskBucketOff = st * num_columns; + const v2RunBind = device.createBindGroup({ + layout: pipelines.layouts.v2Run, + entries: [ + { binding: 0, resource: { buffer: curActive } }, + { binding: 1, resource: { buffer: curCounts } }, + { binding: 2, resource: { buffer: curOffsets } }, + { binding: 3, resource: { buffer: running_x_buf, offset: subtaskBucketOff * ELEMENT_BYTES, size: runningStride } }, + { binding: 4, resource: { buffer: running_y_buf, offset: subtaskBucketOff * ELEMENT_BYTES, size: runningStride } }, + { binding: 5, resource: { buffer: bucket_active_buf, offset: subtaskBucketOff * 4, size: bucketActiveStride } }, + { binding: 6, resource: { buffer: scratch.v2RunParams } }, + ], + }); + directPass(pipelines.v2ToRunning, v2RunBind, Math.ceil(num_columns / wgi)); + } + + const t0 = performance.now(); + device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + const gpu_wall_ms = performance.now() - t0; + + destroyScratch(scratch); + + return { + levels_per_window: max_levels, + num_subtasks, + num_columns, + total_passes: totalPasses, + gpu_wall_ms, + }; +} + +export function maxChunksUpperBound(input_size: number, num_columns: number, s: number): number { + return Math.max(1, Math.ceil(input_size / 2 / s) + num_columns); +} + +export const sizes = { + activeSumsBytes(input_size: number): number { + const M = input_size + 2; + return 2 * PG * M * 16; + }, + chainBufBytes(input_size: number, num_columns: number, s: number): number { + const T = maxChunksUpperBound(input_size, num_columns, s); + return 2 * PG * (2 * s * T) * 16; + }, + tempOutBytes(input_size: number, num_columns: number, s: number): number { + const T = maxChunksUpperBound(input_size, num_columns, s); + return 2 * PG * (s * T) * 16; + }, + chunkPlanBytes(input_size: number, num_columns: number, s: number): number { + const T = maxChunksUpperBound(input_size, 
num_columns, s); + return 2 * s * T * 4; + }, + scatterPlanBytes(input_size: number, num_columns: number, s: number): number { + const T = maxChunksUpperBound(input_size, num_columns, s); + return s * T * 4; + }, + carryPlanBytes(num_columns: number): number { + return 2 * num_columns * 4; + }, + countsBytes(num_columns: number): number { + return num_columns * 4; + }, + offsetsBytes(num_columns: number): number { + return num_columns * 4; + }, +}; diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index 741d3480f737..bef7e380beeb 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -1,6 +1,6 @@ // AUTO-GENERATED by scripts/inline-wgsl.mjs. DO NOT EDIT. // Run `yarn generate:wgsl` (or `node scripts/inline-wgsl.mjs`) to regenerate. -// 48 shader sources inlined. +// 70 shader sources inlined. /* eslint-disable */ @@ -1351,6 +1351,1870 @@ fn main(@builtin(global_invocation_id) gid: vec3) { } `; +export const ba_carry_copy_bench = `{{> structs }} + +// Carry-copy kernel for the bin-packed pair-tree MSM bucket-accumulate. +// +// For each carry slot t, copies one packed (x, y) point from +// active_sums_old[carry_plan[2*t + 0]] to +// active_sums_new[carry_plan[2*t + 1]]. +// +// Used when a bucket has an odd active count at the current level: +// floor(N_b / 2) elements get paired and produce floor(N_b / 2) sums +// in the next level, plus the (N_b mod 2 == 1) carry element propagates +// forward unchanged. +// +// Pure memory shuffle, no field arithmetic. 
+// +// params.x = T (number of carry-copies / threads) +// params.y = M_old (active_sums_old size, vec4-stride scaling) +// params.z = M_new (active_sums_new size, vec4-stride scaling) + +const PG: u32 = 2u; + +@group(0) @binding(0) var carry_plan: array; +@group(0) @binding(1) var active_sums_old: array>; +@group(0) @binding(2) var active_sums_new: array>; +@group(0) @binding(3) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let M_old = params.y; + let M_new = params.z; + let t = gid.x; + if (t >= T) { return; } + + let src_idx = carry_plan[2u * t + 0u]; + let dst_idx = carry_plan[2u * t + 1u]; + + let old_plane_x = 0u * PG * M_old; + let old_plane_y = 1u * PG * M_old; + let new_plane_x = 0u * PG * M_new; + let new_plane_y = 1u * PG * M_new; + + let src_x = old_plane_x + PG * src_idx; + let src_y = old_plane_y + PG * src_idx; + let dst_x = new_plane_x + PG * dst_idx; + let dst_y = new_plane_y + PG * dst_idx; + + active_sums_new[dst_x + 0u] = active_sums_old[src_x + 0u]; + active_sums_new[dst_x + 1u] = active_sums_old[src_x + 1u]; + active_sums_new[dst_y + 0u] = active_sums_old[src_y + 0u]; + active_sums_new[dst_y + 1u] = active_sums_old[src_y + 1u]; + + {{{ recompile }}} +} +`; + +export const ba_carry_copy_prod = `{{> structs }} + +// Carry-copy kernel — prod variant for the v2 pair-tree integration. +// num_carries is read from the planner's totals[1] and dispatch is +// indirect via totals[7..9]. 
+ +const PG: u32 = 2u; + +@group(0) @binding(0) var carry_plan: array; +@group(0) @binding(1) var active_sums_old: array>; +@group(0) @binding(2) var active_sums_new: array>; +@group(0) @binding(3) var totals: array; +@group(0) @binding(4) var consts: vec4; +// consts.x = M_old +// consts.y = M_new + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = totals[1]; + let M_old = consts.x; + let M_new = consts.y; + let t = gid.x; + if (t >= T) { return; } + + let src_idx = carry_plan[2u * t + 0u]; + let dst_idx = carry_plan[2u * t + 1u]; + + let old_plane_x = 0u * PG * M_old; + let old_plane_y = 1u * PG * M_old; + let new_plane_x = 0u * PG * M_new; + let new_plane_y = 1u * PG * M_new; + + let src_x = old_plane_x + PG * src_idx; + let src_y = old_plane_y + PG * src_idx; + let dst_x = new_plane_x + PG * dst_idx; + let dst_y = new_plane_y + PG * dst_idx; + + active_sums_new[dst_x + 0u] = active_sums_old[src_x + 0u]; + active_sums_new[dst_x + 1u] = active_sums_old[src_x + 1u]; + active_sums_new[dst_y + 0u] = active_sums_old[src_y + 0u]; + active_sums_new[dst_y + 1u] = active_sums_old[src_y + 1u]; + + {{{ recompile }}} +} +`; + +export const ba_fused_super_bench = `{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +// Fused super-kernel for the bin-packed pair-tree MSM bucket-accumulate. +// +// Combines marshal + disjoint + scatter into one kernel. Each thread t +// handles one chunk of S pairs: +// 1. Read 2*S source indices from chunk_plan (idx_l, idx_r per slot). +// 2. Read S destination indices from scatter_plan. +// 3. Load S pair-x values from active_sums_old, compute S dx values +// and forward prefix product, all in registers. +// 4. Single fr_inv_by_a on the prefix product. +// 5. 
Backward peel: per slot k from S-1 down to 0: +// - load .x and .y for both operands +// - lean affine add -> R_x, R_y +// - write directly to active_sums_new at scatter_plan[t*S + k] +// - update inv for next (smaller-k) iteration +// +// vs v2 (4 kernels: marshal, disjoint, scatter, carry): the chain_buf +// and tempOut scratch buffers are eliminated. All intermediate state +// lives in registers. Per-level dispatch count drops from 4 to 2 +// (fused + carry). +// +// PARAMS: +// params.x = T_chunks (active threads, one per chunk) +// params.y = M_old (active_sums_old vec4-stride length) +// params.z = M_new (active_sums_new vec4-stride length) +// +// Layout (both active_sums buffers): 2 planes (P.x, P.y), PG=2 vec4 per +// element. plane_p flat vec4 base = p * PG * M, element e at offset +// PG * e. + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var chunk_plan: array; +@group(0) @binding(1) var scatter_plan: array; +@group(0) @binding(2) var active_sums_old: array>; +@group(0) @binding(3) var active_sums_new: array>; +@group(0) @binding(4) var params: vec4; + +fn load_active_x(idx: u32, M: u32) -> BigInt { + let plane_base = 0u * PG * M; + let base = plane_base + PG * idx; + let q0 = active_sums_old[base + 0u]; + let q1 = active_sums_old[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn load_active_y(idx: u32, M: u32) -> BigInt { + let plane_base = 1u * PG * M; + let base = plane_base + PG * idx; + let q0 = active_sums_old[base + 0u]; + let q1 = active_sums_old[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_active_new(plane: u32, idx: u32, M: u32, val: ptr) { + let plane_base = plane * PG * M; + let base = plane_base + PG * idx; + let w = pack_limbs_to_256(val); + 
active_sums_new[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + active_sums_new[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let M_old = params.y; + let M_new = params.z; + let t = gid.x; + if (t >= T) { return; } + + let chunk_base = 2u * S * t; + + // Forward: compute S dx values and accumulate prefix product. + // Read pair indices from chunk_plan, load .x for each operand, compute dx. + var pref: array; + var acc: BigInt = get_r(); + for (var k: u32 = 0u; k < S; k = k + 1u) { + let idx_l = chunk_plan[chunk_base + 2u * k + 0u]; + let idx_r = chunk_plan[chunk_base + 2u * k + 1u]; + var p_lx: BigInt = load_active_x(idx_l, M_old); + var p_rx: BigInt = load_active_x(idx_r, M_old); + var dx: BigInt = fr_sub(&p_rx, &p_lx); + if (k == 0u) { + acc = dx; + } else { + acc = montgomery_product(&acc, &dx); + } + pref[k] = acc; + } + + // Single inversion per chunk. + var inv: BigInt = fr_inv_by_a(acc); + + // Backward peel: emit S pair sums, scatter to active_sums_new. 
+ for (var jj: u32 = 0u; jj < S; jj = jj + 1u) { + let k = S - 1u - jj; + let idx_l = chunk_plan[chunk_base + 2u * k + 0u]; + let idx_r = chunk_plan[chunk_base + 2u * k + 1u]; + + var p_lx: BigInt = load_active_x(idx_l, M_old); + var p_ly: BigInt = load_active_y(idx_l, M_old); + var p_rx: BigInt = load_active_x(idx_r, M_old); + var p_ry: BigInt = load_active_y(idx_r, M_old); + + var inv_dx: BigInt; + if (k == 0u) { + inv_dx = inv; + } else { + var pp = pref[k - 1u]; + inv_dx = montgomery_product(&inv, &pp); + } + + var lambda: BigInt = fr_sub(&p_ry, &p_ly); + lambda = montgomery_product(&lambda, &inv_dx); + var r_x: BigInt = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &p_lx); + r_x = fr_sub(&r_x, &p_rx); + var r_y: BigInt = fr_sub(&p_lx, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &p_ly); + + let dst_idx = scatter_plan[t * S + k]; + store_active_new(0u, dst_idx, M_new, &r_x); + store_active_new(1u, dst_idx, M_new, &r_y); + + if (k > 0u) { + var dx_back: BigInt = fr_sub(&p_rx, &p_lx); + inv = montgomery_product(&inv, &dx_back); + } + } + + {{{ recompile }}} +} +`; + +export const ba_marshal_chain_bench = `{{> structs }} + +// Marshal kernel for the bench-msm-chain pipeline. Transposes a CSR +// point list (sorted by bucket) into the strided SoA layout the +// ba_rev_packed_carry_bench chain kernel consumes. +// +// Input layout (point_pool): +// 2 planes (P.x, P.y), each PG=2 vec4 per element, params.y elements total. +// Plane p at point idx i: vec4 indices p*PG*N + PG*i + {0,1}. +// Convention: point_pool[0] is the "decoy" — used as the seed for every +// chunk so the chain kernel's first dx (= P_0.x - seed.x) is well- +// defined. csr_indices values are in [1, N), never 0. +// +// Output layout (chain_buf): +// 4 planes (A.x, A.y, P.x, P.y), each PG=2 vec4 per element, T*S +// elements per plane. Plane p at strided element e = t + i*T: vec4 +// indices p*PG*(T*S) + PG*e + {0,1}. 
+// +// Per chunk-thread t: +// - csr_start = chunk_plan[2*t + 1] (chunk_plan[2*t] = bucket_id, unused here) +// - Seed at index t (planes 0,1) := point_pool[0] (universal decoy) +// - For i in 0..S: P_i at index e = t + i*T (planes 2,3) +// := point_pool[csr_indices[csr_start + i]] +// +// The chain kernel then produces S pair-sums per chunk. The S/2 odd- +// indexed outputs (R_1, R_3, ..., R_{S-1}) are disjoint pair sums of +// {P_0..P_{S-1}}; the even outputs (R_0, R_2, ...) incorporate the +// decoy or share a P with the next odd output and are discarded by the +// subsequent reduce pass. +// +// Pure memory-shuffle kernel: no field arithmetic. Reads are coalesced +// because consecutive threads t, t+1 read adjacent csr_indices entries +// and the gathered point coords are written to adjacent vec4 slots +// (PG*e for e=t, t+1, ...). + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var csr_indices: array; +@group(0) @binding(1) var chunk_plan: array; +@group(0) @binding(2) var point_pool: array>; +@group(0) @binding(3) var chain_buf: array>; +@group(0) @binding(4) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let N = params.y; + let t = gid.x; + if (t >= T) { return; } + + let csr_start = chunk_plan[2u * t + 1u]; + + let chain_N = T * S; + let chain_plane = PG * chain_N; + let chain_ax_base = 0u * chain_plane; + let chain_ay_base = 1u * chain_plane; + let chain_px_base = 2u * chain_plane; + let chain_py_base = 3u * chain_plane; + + let pool_plane = PG * N; + let pool_px_base = 0u * pool_plane; + let pool_py_base = 1u * pool_plane; + + // Seed (A.x, A.y at index t) := point_pool[0] (decoy). 
+ let decoy_x_off = pool_px_base + PG * 0u; + let decoy_y_off = pool_py_base + PG * 0u; + let seed_x_off = chain_ax_base + PG * t; + let seed_y_off = chain_ay_base + PG * t; + chain_buf[seed_x_off + 0u] = point_pool[decoy_x_off + 0u]; + chain_buf[seed_x_off + 1u] = point_pool[decoy_x_off + 1u]; + chain_buf[seed_y_off + 0u] = point_pool[decoy_y_off + 0u]; + chain_buf[seed_y_off + 1u] = point_pool[decoy_y_off + 1u]; + + // Gather S points from csr_indices[csr_start..csr_start+S] into the + // strided P-planes at indices e = t + i*T for i in 0..S. + for (var i = 0u; i < S; i = i + 1u) { + let pt_idx = csr_indices[csr_start + i]; + let e = t + i * T; + let pool_x_off = pool_px_base + PG * pt_idx; + let pool_y_off = pool_py_base + PG * pt_idx; + let chain_px_off = chain_px_base + PG * e; + let chain_py_off = chain_py_base + PG * e; + chain_buf[chain_px_off + 0u] = point_pool[pool_x_off + 0u]; + chain_buf[chain_px_off + 1u] = point_pool[pool_x_off + 1u]; + chain_buf[chain_py_off + 0u] = point_pool[pool_y_off + 0u]; + chain_buf[chain_py_off + 1u] = point_pool[pool_y_off + 1u]; + } + + {{{ recompile }}} +} +`; + +export const ba_marshal_pairs_bench = `{{> structs }} + +// Marshal kernel for the bin-packed pair-tree MSM bucket-accumulate. +// +// Reads (idx_l, idx_r) operand indices per pair from chunk_plan, +// fetches the corresponding packed 8x u32 points from an active_sums +// buffer (2-plane SoA), and writes them into the disjoint kernel's +// strided input layout. +// +// Used both at level 0 (active_sums = bucket-sorted point pool) and +// at levels 1+ (active_sums = previous level's pair-sum + carry +// outputs). The kernel is bucket-agnostic; the planner has packed +// each chunk's S pairs from whatever buckets fit, and chunk_plan +// encodes the operand source indices. 
+// +// chunk_plan layout: 2 * S u32 per chunk +// chunk_plan[2 * (t * S + k) + 0] = idx_left (active_sums index) +// chunk_plan[2 * (t * S + k) + 1] = idx_right (active_sums index) +// +// active_sums layout: 2 planes (P.x, P.y), PG=2 vec4 per element, +// M_in elements per plane (params.y). +// +// chain_buf layout: 2 planes (P.x, P.y), PG=2 vec4 per element, +// 2 * S * T elements per plane. Slot (t, 2k+0) holds left, slot +// (t, 2k+1) holds right at the disjoint kernel's strided positions +// e = t + i * T for i = 2k, 2k+1. + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var chunk_plan: array; +@group(0) @binding(1) var active_sums: array>; +@group(0) @binding(2) var chain_buf: array>; +@group(0) @binding(3) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let M_in = params.y; + let t = gid.x; + if (t >= T) { return; } + + let chain_N = 2u * S * T; + let chain_plane_x = 0u * PG * chain_N; + let chain_plane_y = 1u * PG * chain_N; + + let active_plane_x = 0u * PG * M_in; + let active_plane_y = 1u * PG * M_in; + + let chunk_base = 2u * S * t; + for (var k: u32 = 0u; k < S; k = k + 1u) { + let idx_l = chunk_plan[chunk_base + 2u * k + 0u]; + let idx_r = chunk_plan[chunk_base + 2u * k + 1u]; + + let e_l = t + (2u * k + 0u) * T; + let e_r = t + (2u * k + 1u) * T; + + let src_lx = active_plane_x + PG * idx_l; + let src_ly = active_plane_y + PG * idx_l; + let src_rx = active_plane_x + PG * idx_r; + let src_ry = active_plane_y + PG * idx_r; + + let dst_lx = chain_plane_x + PG * e_l; + let dst_ly = chain_plane_y + PG * e_l; + let dst_rx = chain_plane_x + PG * e_r; + let dst_ry = chain_plane_y + PG * e_r; + + chain_buf[dst_lx + 0u] = active_sums[src_lx + 0u]; + chain_buf[dst_lx + 1u] = active_sums[src_lx + 1u]; + chain_buf[dst_ly + 0u] = active_sums[src_ly + 0u]; + chain_buf[dst_ly + 1u] = active_sums[src_ly + 1u]; + chain_buf[dst_rx + 0u] = 
active_sums[src_rx + 0u]; + chain_buf[dst_rx + 1u] = active_sums[src_rx + 1u]; + chain_buf[dst_ry + 0u] = active_sums[src_ry + 0u]; + chain_buf[dst_ry + 1u] = active_sums[src_ry + 1u]; + } + + {{{ recompile }}} +} +`; + +export const ba_marshal_pairs_prod = `{{> structs }} + +// Marshal kernel — prod variant for the v2 pair-tree integration. +// +// Same indexing math as ba_marshal_pairs_bench. The only structural +// change: the per-level T (= num_chunks) is read from the planner's +// totals[3] storage output instead of a host-set uniform, and the +// host dispatches via dispatchWorkgroupsIndirect(totals, 16). This +// dispatches exactly ceil(num_chunks / WG) workgroups so no pad +// chunks are computed. + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var chunk_plan: array; +@group(0) @binding(1) var active_sums: array>; +@group(0) @binding(2) var chain_buf: array>; +@group(0) @binding(3) var totals: array; +@group(0) @binding(4) var consts: vec4; +// consts.x = M_in + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = totals[3]; + let M_in = consts.x; + let t = gid.x; + if (t >= T) { return; } + + let chain_N = 2u * S * T; + let chain_plane_x = 0u * PG * chain_N; + let chain_plane_y = 1u * PG * chain_N; + + let active_plane_x = 0u * PG * M_in; + let active_plane_y = 1u * PG * M_in; + + let chunk_base = 2u * S * t; + for (var k: u32 = 0u; k < S; k = k + 1u) { + let idx_l = chunk_plan[chunk_base + 2u * k + 0u]; + let idx_r = chunk_plan[chunk_base + 2u * k + 1u]; + + let e_l = t + (2u * k + 0u) * T; + let e_r = t + (2u * k + 1u) * T; + + let src_lx = active_plane_x + PG * idx_l; + let src_ly = active_plane_y + PG * idx_l; + let src_rx = active_plane_x + PG * idx_r; + let src_ry = active_plane_y + PG * idx_r; + + let dst_lx = chain_plane_x + PG * e_l; + let dst_ly = chain_plane_y + PG * e_l; + let dst_rx = chain_plane_x + PG * e_r; + let dst_ry = chain_plane_y + PG * e_r; + + 
chain_buf[dst_lx + 0u] = active_sums[src_lx + 0u]; + chain_buf[dst_lx + 1u] = active_sums[src_lx + 1u]; + chain_buf[dst_ly + 0u] = active_sums[src_ly + 0u]; + chain_buf[dst_ly + 1u] = active_sums[src_ly + 1u]; + chain_buf[dst_rx + 0u] = active_sums[src_rx + 0u]; + chain_buf[dst_rx + 1u] = active_sums[src_rx + 1u]; + chain_buf[dst_ry + 0u] = active_sums[src_ry + 0u]; + chain_buf[dst_ry + 1u] = active_sums[src_ry + 1u]; + } + + {{{ recompile }}} +} +`; + +export const ba_marshal_tree_l0_bench = `{{> structs }} + +// Marshal kernel for the bench-msm-tree pair-tree pipeline: transposes +// a CSR-sorted point index list into the 2-plane strided SoA layout +// the ba_pair_disjoint_tree kernel consumes at level 0. Pure memory +// shuffle, no field arithmetic. +// +// Input (point_pool): +// 2 planes (P.x, P.y), each PG=2 vec4 per element, N pool elements. +// Plane p flat vec4 indices: p*PG*N + PG*i + {0,1}. +// +// Output (chain_buf): +// 2 planes (P.x, P.y), each PG=2 vec4 per element, 2*S*T elements +// per plane. Plane p at strided element e = t + i*T: vec4 indices +// p*PG*(2*S*T) + PG*e + {0,1}. 
+// +// Per chunk-thread t with CSR slice [csr_start, csr_start + 2*S): +// For i in 0..2*S: +// pt_idx = csr_indices[csr_start + i] +// copy point_pool[pt_idx] (P.x, P.y) into chain_buf at e = t + i*T + +const S: u32 = {{ s }}u; +const TWOS: u32 = 2u * S; +const PG: u32 = 2u; + +@group(0) @binding(0) var csr_indices: array; +@group(0) @binding(1) var chunk_plan: array; +@group(0) @binding(2) var point_pool: array>; +@group(0) @binding(3) var chain_buf: array>; +@group(0) @binding(4) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let N = params.y; + let t = gid.x; + if (t >= T) { return; } + + let csr_start = chunk_plan[2u * t + 1u]; + + let chain_N = TWOS * T; + let chain_plane = PG * chain_N; + let chain_px_base = 0u * chain_plane; + let chain_py_base = 1u * chain_plane; + + let pool_plane = PG * N; + let pool_px_base = 0u * pool_plane; + let pool_py_base = 1u * pool_plane; + + for (var i: u32 = 0u; i < TWOS; i = i + 1u) { + let pt_idx = csr_indices[csr_start + i]; + let e = t + i * T; + let pool_x_off = pool_px_base + PG * pt_idx; + let pool_y_off = pool_py_base + PG * pt_idx; + let chain_px_off = chain_px_base + PG * e; + let chain_py_off = chain_py_base + PG * e; + chain_buf[chain_px_off + 0u] = point_pool[pool_x_off + 0u]; + chain_buf[chain_px_off + 1u] = point_pool[pool_x_off + 1u]; + chain_buf[chain_py_off + 0u] = point_pool[pool_y_off + 0u]; + chain_buf[chain_py_off + 1u] = point_pool[pool_y_off + 1u]; + } + + {{{ recompile }}} +} +`; + +export const ba_pair_disjoint_bench = `{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +// Disjoint pair-sum kernel — each thread reduces 2*S input points to S +// disjoint pair sums R_k = P_{2k} + P_{2k+1} (k in 0..S) using the +// same forward-prefix / single-inversion / 
backward-peel batched- +// inverse pattern as ba_rev_packed_carry, but with NO load-carry +// overlap. Every kernel-output is a distinct pair sum suitable as +// input to the next level of a pair-tree reduction — closes the 50% +// kernel-efficiency loss inherent in the streaming chain kernel. +// +// Storage: SoA-packed 8x u32 per field (PG=2 vec4/elem). +// Input planes (binding 0): +// plane 0 (P.x): PG * N_in vec4, N_in = 2*S*T +// plane 1 (P.y): PG * N_in vec4 +// Output planes (binding 2): +// plane 0 (R.x): PG * N_out vec4, N_out = S*T +// plane 1 (R.y): PG * N_out vec4 +// +// Thread t reads P_i = (inp[plane c at index t + i*T] : c in {0,1}) for +// i in 0..2S (strided => coalesced). Pair k pairs adjacent strided +// slots: (P_{2k}, P_{2k+1}). Output R_k is written at index t + k*T in +// plane c of outp (also strided, coalesced). +// +// dx values dx_k = P_{2k+1}.x - P_{2k}.x are all mutually independent +// (no shared inputs across k), so the standard Montgomery batched +// inverse trick applies as-is: ONE fr_inv_by_a per chunk of S. +// +// Same Karatsuba+Yuval montmul and BY-safegcd fr_inv_by_a as the +// production stack and the chain kernel. 
+ +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var inp: array>; +@group(0) @binding(1) var unused: array>; +@group(0) @binding(2) var outp: array>; +@group(0) @binding(3) var params: vec4; + +fn load_in(plane: u32, t: u32, i: u32, T: u32, N_in: u32) -> BigInt { + let plane_base = plane * PG * N_in; + let base = plane_base + PG * (t + i * T); + let q0 = inp[base + 0u]; + let q1 = inp[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_out(plane: u32, t: u32, k: u32, T: u32, N_out: u32, val: ptr) { + let plane_base = plane * PG * N_out; + let base = plane_base + PG * (t + k * T); + let w = pack_limbs_to_256(val); + outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let N_in = params.x; + let T = params.y; + let N_out = N_in / 2u; + + let t = gid.x; + if (t >= T) { return; } + + // Forward: prefix product of S independent dx values. + var pref: array; + var acc: BigInt = get_r(); + for (var k: u32 = 0u; k < S; k = k + 1u) { + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T, N_in); + var dx: BigInt = fr_sub(&p_rx, &p_lx); + if (k == 0u) { + acc = dx; + } else { + acc = montgomery_product(&acc, &dx); + } + pref[k] = acc; + } + + // One BY-safegcd inversion amortised over all S pair sums. + var inv: BigInt = fr_inv_by_a(acc); + + // Backward peel: emit S disjoint pair sums. 
+ for (var jj: u32 = 0u; jj < S; jj = jj + 1u) { + let k = S - 1u - jj; + + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T, N_in); + var p_ly: BigInt = load_in(1u, t, 2u * k + 0u, T, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T, N_in); + var p_ry: BigInt = load_in(1u, t, 2u * k + 1u, T, N_in); + + var inv_dx: BigInt; + if (k == 0u) { + inv_dx = inv; + } else { + var pp = pref[k - 1u]; + inv_dx = montgomery_product(&inv, &pp); + } + + var lambda: BigInt = fr_sub(&p_ry, &p_ly); + lambda = montgomery_product(&lambda, &inv_dx); + var r_x: BigInt = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &p_lx); + r_x = fr_sub(&r_x, &p_rx); + var r_y: BigInt = fr_sub(&p_lx, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &p_ly); + + store_out(0u, t, k, T, N_out, &r_x); + store_out(1u, t, k, T, N_out, &r_y); + + // Advance inv to 1/pref[k-1] for the next (smaller) iteration. + if (k > 0u) { + var dx_back: BigInt = fr_sub(&p_rx, &p_lx); + inv = montgomery_product(&inv, &dx_back); + } + } + + {{{ recompile }}} +} +`; + +export const ba_pair_disjoint_tree_bench = `{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +// Disjoint pair-sum kernel — tree variant. Each thread reduces 2*S +// input points to S disjoint pair sums R_k = P_{2k} + P_{2k+1}, using +// one batched fr_inv_by_a per chunk of S. +// +// vs ba_pair_disjoint_bench: writes outputs in the LAYOUT THE NEXT +// PAIR-TREE LEVEL EXPECTS AS INPUT, eliminating the need for an +// intervening marshal/reshuffle dispatch between levels. 
+// +// Strided read at level k: thread t reads input slot i at flat +// in_pos(t, i) = t + i * T_curr (i in [0, 2*S)) +// +// Strided write that next level reads correctly: thread t writes +// output slot i at flat +// out_pos(t, i) = (t >> 1) + (i + S * (t & 1)) * (T_curr >> 1) +// +// Derivation: next level uses T_next = T_curr / 2 threads. For +// next-level thread t_n = t >> 1 to read its 2*S inputs in the right +// pair-tree order (first S from prev thread (2*t_n), next S from prev +// thread (2*t_n + 1)), the current level's output slots interleave: +// odd-t writes go into the upper-S input slots of the next level's +// thread (t >> 1), even-t into the lower-S slots. +// +// This preserves the per-bucket-pair invariant: at every level, the +// disjoint pairs (P_{2j}, P_{2j+1}) belong to the same bucket pool, +// so the lean affine formula is always combining points whose dx is +// well-defined. +// +// PARAMS: +// params.x = N_in = 2 * S * T_curr (total input elements per plane) +// params.y = T_curr +// +// LAYOUT (both input and output buffers): +// 2 planes (P.x, P.y), PG=2 vec4 per element. +// Plane p flat index for vec4 access: p * PG * N_buf + PG * e + {0,1} +// where N_buf is the elements-per-plane for that buffer. +// Input buffer's N_buf = 2 * S * T_curr (= N_in). +// Output buffer's N_buf = S * T_curr (= N_in / 2). 
+ +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var inp: array>; +@group(0) @binding(1) var unused: array>; +@group(0) @binding(2) var outp: array>; +@group(0) @binding(3) var params: vec4; + +fn load_in(plane: u32, t: u32, i: u32, T: u32, N_in: u32) -> BigInt { + let plane_base = plane * PG * N_in; + let base = plane_base + PG * (t + i * T); + let q0 = inp[base + 0u]; + let q1 = inp[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_out_tree(plane: u32, t: u32, k: u32, T_curr: u32, N_out: u32, val: ptr) { + // Tree write: out_pos(t, k) = (t >> 1) + (k + S * (t & 1)) * (T_curr >> 1) + // Lands in next-level strided read at index (t >> 1) with slot + // (k + S * (t & 1)). + let t_next = t >> 1u; + let slot_in_next = k + S * (t & 1u); + let T_next = T_curr >> 1u; + let plane_base = plane * PG * N_out; + let elem = t_next + slot_in_next * T_next; + let base = plane_base + PG * elem; + let w = pack_limbs_to_256(val); + outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn store_out_simple(plane: u32, t: u32, k: u32, T_curr: u32, N_out: u32, val: ptr) { + // Final-level simple strided write: out_pos(t, k) = t + k * T_curr. + // Used when there is no next pair-tree level (T_curr == 1 thread, or + // the host indicates this is the last reduction step). 
+ let plane_base = plane * PG * N_out; + let elem = t + k * T_curr; + let base = plane_base + PG * elem; + let w = pack_limbs_to_256(val); + outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let N_in = params.x; + let T_curr = params.y; + let final_flag = params.z; // non-zero => use simple strided write + let N_out = N_in / 2u; + + let t = gid.x; + if (t >= T_curr) { return; } + + var pref: array; + var acc: BigInt = get_r(); + for (var k: u32 = 0u; k < S; k = k + 1u) { + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T_curr, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T_curr, N_in); + var dx: BigInt = fr_sub(&p_rx, &p_lx); + if (k == 0u) { + acc = dx; + } else { + acc = montgomery_product(&acc, &dx); + } + pref[k] = acc; + } + + var inv: BigInt = fr_inv_by_a(acc); + + for (var jj: u32 = 0u; jj < S; jj = jj + 1u) { + let k = S - 1u - jj; + + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T_curr, N_in); + var p_ly: BigInt = load_in(1u, t, 2u * k + 0u, T_curr, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T_curr, N_in); + var p_ry: BigInt = load_in(1u, t, 2u * k + 1u, T_curr, N_in); + + var inv_dx: BigInt; + if (k == 0u) { + inv_dx = inv; + } else { + var pp = pref[k - 1u]; + inv_dx = montgomery_product(&inv, &pp); + } + + var lambda: BigInt = fr_sub(&p_ry, &p_ly); + lambda = montgomery_product(&lambda, &inv_dx); + var r_x: BigInt = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &p_lx); + r_x = fr_sub(&r_x, &p_rx); + var r_y: BigInt = fr_sub(&p_lx, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &p_ly); + + if (final_flag != 0u) { + store_out_simple(0u, t, k, T_curr, N_out, &r_x); + store_out_simple(1u, t, k, T_curr, N_out, &r_y); + } else { + store_out_tree(0u, t, k, T_curr, N_out, 
&r_x); + store_out_tree(1u, t, k, T_curr, N_out, &r_y); + } + + if (k > 0u) { + var dx_back: BigInt = fr_sub(&p_rx, &p_lx); + inv = montgomery_product(&inv, &dx_back); + } + } + + {{{ recompile }}} +} +`; + +export const ba_pair_disjoint_tree_prod = `{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +// Disjoint pair-sum kernel — prod variant for the v2 pair-tree +// integration. Same disjoint pair-sum math as +// ba_pair_disjoint_tree_bench (suffix-product single fr_inv_by_a per +// chunk + lean affine add); the per-level T (= num_chunks) is read +// from the planner's totals[3] storage output and the dispatch happens +// indirectly so only real chunks run. Always uses the final-mode +// strided write (matches what ba_scatter_pairs_prod expects). +// +// LAYOUT: same as the bench variant. Combined-SoA input/output (2 +// planes, PG=2 vec4 per element, plane-major then element-major then +// vec4 within an element). 
+ +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var inp: array>; +@group(0) @binding(1) var unused: array>; +@group(0) @binding(2) var outp: array>; +@group(0) @binding(3) var totals: array; + +fn load_in(plane: u32, t: u32, i: u32, T: u32, N_in: u32) -> BigInt { + let plane_base = plane * PG * N_in; + let base = plane_base + PG * (t + i * T); + let q0 = inp[base + 0u]; + let q1 = inp[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_out_simple(plane: u32, t: u32, k: u32, T_curr: u32, N_out: u32, val: ptr) { + let plane_base = plane * PG * N_out; + let elem = t + k * T_curr; + let base = plane_base + PG * elem; + let w = pack_limbs_to_256(val); + outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T_curr = totals[3]; + let N_in = 2u * S * T_curr; + let N_out = S * T_curr; + + let t = gid.x; + if (t >= T_curr) { return; } + + var pref: array; + var acc: BigInt = get_r(); + for (var k: u32 = 0u; k < S; k = k + 1u) { + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T_curr, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T_curr, N_in); + var dx: BigInt = fr_sub(&p_rx, &p_lx); + if (k == 0u) { + acc = dx; + } else { + acc = montgomery_product(&acc, &dx); + } + pref[k] = acc; + } + + var inv: BigInt = fr_inv_by_a(acc); + + for (var jj: u32 = 0u; jj < S; jj = jj + 1u) { + let k = S - 1u - jj; + + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T_curr, N_in); + var p_ly: BigInt = load_in(1u, t, 2u * k + 0u, T_curr, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T_curr, N_in); + var p_ry: BigInt = load_in(1u, t, 2u * k + 1u, T_curr, N_in); + + var inv_dx: BigInt; + if (k == 
0u) { + inv_dx = inv; + } else { + var pp = pref[k - 1u]; + inv_dx = montgomery_product(&inv, &pp); + } + + var lambda: BigInt = fr_sub(&p_ry, &p_ly); + lambda = montgomery_product(&lambda, &inv_dx); + var r_x: BigInt = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &p_lx); + r_x = fr_sub(&r_x, &p_rx); + var r_y: BigInt = fr_sub(&p_lx, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &p_ly); + + store_out_simple(0u, t, k, T_curr, N_out, &r_x); + store_out_simple(1u, t, k, T_curr, N_out, &r_y); + + if (k > 0u) { + var dx_back: BigInt = fr_sub(&p_rx, &p_lx); + inv = montgomery_product(&inv, &dx_back); + } + } + + {{{ recompile }}} +} +`; + +export const ba_planner_bench = `{{> structs }} + +// GPU-side bin-packing planner for the v3 MSM bucket-accumulate +// pipeline. One thread per bucket; uses atomicAdd to reserve global +// per-pair slots in chunk_plan / scatter_plan and per-carry slots in +// carry_plan, then writes that bucket's entries. +// +// Inputs (per current level): +// counts: array per-bucket active count +// offsets: array per-bucket starting index in active_sums_old +// +// Outputs (filled in by this kernel for the current level): +// chunk_plan: array 2 u32 per (chunk_id, slot) — pair operand indices +// scatter_plan: array 1 u32 per (chunk_id, slot) — destination in active_sums_new +// carry_plan: array 2 u32 per carry slot — (src in old, dst in new) +// totals: array> [0]=total pairs, [1]=total carries, [2]=total new actives +// new_counts: array per-bucket new active count (for next level) +// new_offsets: array per-bucket new offset in active_sums_new (for next level) +// +// Convention: discard slot = M_new - 1 (the highest index in +// active_sums_new). Pad pair source indices = (pad_l_idx, pad_r_idx) +// supplied via params. All non-real chunk_plan / scatter_plan slots +// must be pre-padded to (pad_l_idx, pad_r_idx) and discard_idx by the +// host before each planner dispatch. 
+// +// params.x = B (bucket count) +// params.y = S (chunk size, slots per chunk); informational only — the +// kernel below reads only params.x and uses the compile-time S constant +// (pad_l_idx / pad_r_idx / discard_idx live in the pre-padded +// arrays, not in params) + +const S: u32 = {{ s }}u; + +@group(0) @binding(0) var counts: array; +@group(0) @binding(1) var offsets: array; +@group(0) @binding(2) var chunk_plan: array; +@group(0) @binding(3) var scatter_plan: array; +@group(0) @binding(4) var carry_plan: array; +@group(0) @binding(5) var totals: array>; +@group(0) @binding(6) var new_counts: array; +@group(0) @binding(7) var new_offsets: array; +@group(0) @binding(8) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let B = params.x; + let b = gid.x; + if (b >= B) { return; } + + let n = counts[b]; + let pair_count = n / 2u; + let carry_flag = n & 1u; + let nc = pair_count + carry_flag; + new_counts[b] = nc; + + // Atomic offset reservation. Each bucket gets a unique non-overlapping + // range in the global arrays. Atomic order is non-deterministic but + // that's fine: bucket b records its assigned offsets and uses them + // consistently for its own chunk_plan / scatter_plan / new_offsets + // writes. Different buckets land in different ranges by construction. + let my_pair_off = atomicAdd(&totals[0u], pair_count); + let my_carry_off = atomicAdd(&totals[1u], carry_flag); + let my_new_off = atomicAdd(&totals[2u], nc); + new_offsets[b] = my_new_off; + + let bucket_base = offsets[b]; + + // Write this bucket's pair entries into chunk_plan / scatter_plan. + // Loop bounded by pair_count (variable per bucket; typically ~16 + // for Poisson(λ=32)). The TAIL_CAP-style compile-time bound used + // by ba_tail_reduce isn't strictly needed here since this kernel + // doesn't do field arithmetic; the loop is plain integer writes. + // We still bound it by a compile-time constant for WGSL static + // analysis purposes. 
+ let PAIR_CAP: u32 = {{ pair_cap }}u; + for (var j: u32 = 0u; j < PAIR_CAP; j = j + 1u) { + if (j >= pair_count) { break; } + let global_slot = my_pair_off + j; + let chunk_id = global_slot / S; + let slot_in_chunk = global_slot % S; + let cp_base = 2u * (chunk_id * S + slot_in_chunk); + chunk_plan[cp_base + 0u] = bucket_base + 2u * j; + chunk_plan[cp_base + 1u] = bucket_base + 2u * j + 1u; + scatter_plan[chunk_id * S + slot_in_chunk] = my_new_off + j; + } + + if (carry_flag != 0u) { + let cs = my_carry_off; + carry_plan[2u * cs + 0u] = bucket_base + n - 1u; + carry_plan[2u * cs + 1u] = my_new_off + pair_count; + } + + {{{ recompile }}} +} +`; + +export const ba_planner_v2_bench = `{{> structs }} + +// Optimal single-kernel GPU bin-packing planner for the MSM +// bucket-accumulate pair-tree. +// +// One workgroup of TPB threads processes B buckets. Each thread +// handles PER_THREAD = B / TPB buckets via a contiguous slice +// [tid * PER_THREAD, (tid+1) * PER_THREAD). +// +// Phase A — Per-thread local scan +// For each of its PER_THREAD buckets, compute (pair_count, carry_flag, +// new_count). Accumulate per-thread totals (sum across the thread's +// slice). Keep the per-bucket triples in registers; we will re-read +// them in Phase C (Phase B only scans the per-thread totals). +// +// Phase B — Workgroup-wide Hillis-Steele scan (3 in parallel) +// Scan the per-thread totals for pair, carry, new across the TPB +// threads in shared memory. Result: each thread gets the global +// prefix sum at the START of its slice (= base offset for its first +// bucket). +// +// Phase C — Per-thread scatter +// For each bucket in the thread's slice (in order), use the running +// thread-local offset to compute global pair_offset_b and write the +// pair_count[b] chunk_plan entries plus the (optional) carry_plan +// entry. Update local running offsets. Write new_counts[b] and +// new_offsets[b] for the next level. +// +// Phase D — One thread writes totals. 
+// totals[0] = total_pairs, totals[1] = total_carries, +// totals[2] = total_new_actives. +// +// Single dispatch. No atomics. No host sync. Scales to B = TPB * +// PER_THREAD (e.g. 256 * 32 = 8192) within one workgroup. Larger B +// requires multi-workgroup scan + global combine (out of scope here). +// +// Compile-time constants: +// TPB : workgroup size (e.g. 256) +// PER_THREAD : buckets per thread (e.g. 16 for B=4096, 32 for B=8192) +// PAIR_CAP : bound on per-bucket pair count (Poisson(λ=32) tail +// is ~30; choose 64 for safety) +// S : chunk size in pairs (e.g. 16) + +const TPB: u32 = {{ workgroup_size }}u; +const PER_THREAD: u32 = {{ per_thread }}u; +const PAIR_CAP: u32 = {{ pair_cap }}u; +const S: u32 = {{ s }}u; + +@group(0) @binding(0) var counts: array; +@group(0) @binding(1) var offsets: array; +@group(0) @binding(2) var chunk_plan: array; +@group(0) @binding(3) var scatter_plan: array; +@group(0) @binding(4) var carry_plan: array; +@group(0) @binding(5) var new_counts: array; +@group(0) @binding(6) var new_offsets: array; +@group(0) @binding(7) var totals: array; +@group(0) @binding(8) var params: vec4; +// params.x = B + +// Workgroup-shared running prefixes for the 3 scans. +var pair_scan: array; +var carry_scan: array; +var new_scan: array; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; + let B = params.x; + + // Phase A: per-thread local read + accumulate. + // Keep PER_THREAD bucket triples in registers (small array). 
+ var local_pc: array; + var local_cf: array; + var local_nc: array; + var sum_p: u32 = 0u; + var sum_c: u32 = 0u; + var sum_n: u32 = 0u; + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tid * PER_THREAD + k; + var pc: u32 = 0u; + var cf: u32 = 0u; + var nc: u32 = 0u; + if (b < B) { + let n = counts[b]; + pc = n / 2u; + cf = n & 1u; + nc = pc + cf; + } + local_pc[k] = pc; + local_cf[k] = cf; + local_nc[k] = nc; + sum_p += pc; + sum_c += cf; + sum_n += nc; + } + + // Phase B: workgroup-wide Hillis-Steele inclusive scan over per- + // thread totals (3 scans interleaved). + pair_scan[tid] = sum_p; + carry_scan[tid] = sum_c; + new_scan[tid] = sum_n; + workgroupBarrier(); + for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) { + var add_p: u32 = 0u; + var add_c: u32 = 0u; + var add_n: u32 = 0u; + if (tid >= stride) { + add_p = pair_scan[tid - stride]; + add_c = carry_scan[tid - stride]; + add_n = new_scan[tid - stride]; + } + workgroupBarrier(); + if (tid >= stride) { + pair_scan[tid] = pair_scan[tid] + add_p; + carry_scan[tid] = carry_scan[tid] + add_c; + new_scan[tid] = new_scan[tid] + add_n; + } + workgroupBarrier(); + } + // pair_scan[tid] is now inclusive prefix. Exclusive base = inclusive - own_sum. + var local_pair_off: u32 = pair_scan[tid] - sum_p; + var local_carry_off: u32 = carry_scan[tid] - sum_c; + var local_new_off: u32 = new_scan[tid] - sum_n; + + // Phase D: the LAST thread (tid == TPB - 1) writes totals; its FINAL + // inclusive scan value is the sum over every thread in the workgroup. + if (tid == TPB - 1u) { + totals[0] = pair_scan[tid]; + totals[1] = carry_scan[tid]; + totals[2] = new_scan[tid]; + } + + // Phase C: per-thread scatter. + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tid * PER_THREAD + k; + if (b >= B) { break; } + + let pc = local_pc[k]; + let cf = local_cf[k]; + let nc = local_nc[k]; + new_counts[b] = nc; + new_offsets[b] = local_new_off; + + let bucket_base = offsets[b]; + + // Pair entries: bounded loop, break at pc. 
+ for (var j: u32 = 0u; j < PAIR_CAP; j = j + 1u) { + if (j >= pc) { break; } + let global_slot = local_pair_off + j; + let chunk_id = global_slot / S; + let slot_in_chunk = global_slot % S; + let cp_base = 2u * (chunk_id * S + slot_in_chunk); + chunk_plan[cp_base + 0u] = bucket_base + 2u * j; + chunk_plan[cp_base + 1u] = bucket_base + 2u * j + 1u; + scatter_plan[chunk_id * S + slot_in_chunk] = local_new_off + j; + } + + // Carry entry (if odd count). + if (cf != 0u) { + let cs = local_carry_off; + carry_plan[2u * cs + 0u] = bucket_base + counts[b] - 1u; + carry_plan[2u * cs + 1u] = local_new_off + pc; + } + + local_pair_off += pc; + local_carry_off += cf; + local_new_off += nc; + } + + {{{ recompile }}} +} +`; + +export const ba_planner_v2_prod = `{{> structs }} + +// Production GPU bin-packing planner for the v2 pair-tree integration. +// +// Same algorithm as ba_planner_v2_bench (one workgroup of TPB threads, +// per-thread local scan, workgroup-wide Hillis-Steele scan over the +// three running sums, per-thread scatter) but extends the totals +// output with the indirect-dispatch counts the production marshal / +// disjoint / scatter / carry kernels need: +// +// totals[0] = total_pairs +// totals[1] = total_carries +// totals[2] = total_new +// totals[3] = num_chunks = (total_pairs + S - 1) / S +// (NOTE(review): the kernel body applies no max(1, ...) clamp, so +// this is 0 when total_pairs == 0 — confirm downstream kernels +// are expected to see T = 0) +// totals[4] = marshal/disjoint/scatter dispatch X (= ceil(num_chunks / WGI)) +// totals[5] = 1 +// totals[6] = 1 +// totals[7] = carry dispatch X (= ceil(total_carries / WGI)) +// totals[8] = 1 +// totals[9] = 1 +// +// The four prod-variant downstream kernels (ba_marshal_pairs_prod, +// ba_pair_disjoint_tree_prod, ba_scatter_pairs_prod, ba_carry_copy_prod) +// read num_chunks and total_carries from this same totals storage +// buffer so a single planner dispatch fully drives the level's runtime +// shape with zero wasted-pad-chunk compute. 
The host orchestrator +// reuses the totals buffer as the indirect-dispatch source via +// dispatchWorkgroupsIndirect(totals, 16) for marshal/disjoint/scatter +// (totals u32 indices 4..6) and dispatchWorkgroupsIndirect(totals, 28) +// for carry (totals u32 indices 7..9). +// +// Compile-time constants: +// TPB : workgroup size (e.g. 256) +// PER_THREAD : buckets per thread +// PAIR_CAP : per-bucket pair-count bound +// S : chunk size in pairs +// WGI : downstream kernel workgroup size — must match the +// workgroup_size of ba_marshal_pairs_prod / +// ba_pair_disjoint_tree_prod / ba_scatter_pairs_prod / +// ba_carry_copy_prod. + +const TPB: u32 = {{ workgroup_size }}u; +const PER_THREAD: u32 = {{ per_thread }}u; +const PAIR_CAP: u32 = {{ pair_cap }}u; +const S: u32 = {{ s }}u; +const WGI: u32 = {{ wgi }}u; + +@group(0) @binding(0) var counts: array; +@group(0) @binding(1) var offsets: array; +@group(0) @binding(2) var chunk_plan: array; +@group(0) @binding(3) var scatter_plan: array; +@group(0) @binding(4) var carry_plan: array; +@group(0) @binding(5) var new_counts: array; +@group(0) @binding(6) var new_offsets: array; +@group(0) @binding(7) var totals: array; +@group(0) @binding(8) var params: vec4; +// params.x = B +// params.y = pad_left_idx (active_sums index used for chunk_plan tail pad left operand) +// params.z = pad_right_idx (chunk_plan tail pad right operand; must differ from pad_left_idx in x) +// params.w = discard_idx (scatter_plan tail dst; slot that the next level never reads) + +var pair_scan: array; +var carry_scan: array; +var new_scan: array; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; + let B = params.x; + + var local_pc: array; + var local_cf: array; + var local_nc: array; + var sum_p: u32 = 0u; + var sum_c: u32 = 0u; + var sum_n: u32 = 0u; + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tid * PER_THREAD + k; + var pc: u32 = 0u; + var cf: u32 = 0u; + 
var nc: u32 = 0u; + if (b < B) { + let n = counts[b]; + pc = n / 2u; + cf = n & 1u; + nc = pc + cf; + } + local_pc[k] = pc; + local_cf[k] = cf; + local_nc[k] = nc; + sum_p += pc; + sum_c += cf; + sum_n += nc; + } + + pair_scan[tid] = sum_p; + carry_scan[tid] = sum_c; + new_scan[tid] = sum_n; + workgroupBarrier(); + for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) { + var add_p: u32 = 0u; + var add_c: u32 = 0u; + var add_n: u32 = 0u; + if (tid >= stride) { + add_p = pair_scan[tid - stride]; + add_c = carry_scan[tid - stride]; + add_n = new_scan[tid - stride]; + } + workgroupBarrier(); + if (tid >= stride) { + pair_scan[tid] = pair_scan[tid] + add_p; + carry_scan[tid] = carry_scan[tid] + add_c; + new_scan[tid] = new_scan[tid] + add_n; + } + workgroupBarrier(); + } + var local_pair_off: u32 = pair_scan[tid] - sum_p; + var local_carry_off: u32 = carry_scan[tid] - sum_c; + var local_new_off: u32 = new_scan[tid] - sum_n; + + if (tid == TPB - 1u) { + let tp = pair_scan[tid]; + let tc = carry_scan[tid]; + let tn = new_scan[tid]; + totals[0] = tp; + totals[1] = tc; + totals[2] = tn; + let num_chunks = (tp + S - 1u) / S; + totals[3] = num_chunks; + totals[4] = (num_chunks + WGI - 1u) / WGI; + totals[5] = 1u; + totals[6] = 1u; + totals[7] = (tc + WGI - 1u) / WGI; + totals[8] = 1u; + totals[9] = 1u; + } + + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tid * PER_THREAD + k; + if (b >= B) { break; } + + let pc = local_pc[k]; + let cf = local_cf[k]; + let nc = local_nc[k]; + new_counts[b] = nc; + new_offsets[b] = local_new_off; + + let bucket_base = offsets[b]; + + for (var j: u32 = 0u; j < PAIR_CAP; j = j + 1u) { + if (j >= pc) { break; } + let global_slot = local_pair_off + j; + let chunk_id = global_slot / S; + let slot_in_chunk = global_slot % S; + let cp_base = 2u * (chunk_id * S + slot_in_chunk); + chunk_plan[cp_base + 0u] = bucket_base + 2u * j; + chunk_plan[cp_base + 1u] = bucket_base + 2u * j + 1u; + scatter_plan[chunk_id * S + 
slot_in_chunk] = local_new_off + j; + } + + if (cf != 0u) { + let cs = local_carry_off; + carry_plan[2u * cs + 0u] = bucket_base + counts[b] - 1u; + carry_plan[2u * cs + 1u] = local_new_off + pc; + } + + local_pair_off += pc; + local_carry_off += cf; + local_new_off += nc; + } + + workgroupBarrier(); + if (tid == TPB - 1u) { + let tp = pair_scan[tid]; + let num_chunks = (tp + S - 1u) / S; + let pad_end = num_chunks * S; + let pad_left = params.y; + let pad_right = params.z; + let discard_idx = params.w; + for (var i: u32 = tp; i < pad_end; i = i + 1u) { + chunk_plan[2u * i + 0u] = pad_left; + chunk_plan[2u * i + 1u] = pad_right; + scatter_plan[i] = discard_idx; + } + } + + {{{ recompile }}} +} +`; + +export const ba_rev_packed_carry_bench = `{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +// MSM-integrated bucket-accumulate batch-affine kernel — packed 8x u32 +// storage + decoupled (full-ILP) pack/unpack + reversed direction + +// resident-accumulator load-carry. Drives the canonical +// ba_rev_packed_carry benchmark that reached ~22 ns/pair on M2 / Chrome +// 148 (-55% vs the production batch-affine kernel). +// +// Math is byte-identical to ba_msm_bucket_bench: forward running +// prefix-product of the S dx values in a private array, ONE +// fr_inv_by_a per chunk of S, backward peel with the lean affine +// formula (dx recomputed free in the backward pass), resident +// accumulator A.x kept in registers across the whole chunk (load-carry: +// A_{i+1} := P_i so the forward and backward passes share one global +// P_i.x load per iteration). Same Karatsuba+Yuval montmul and BY-safegcd +// fr_inv_by_a as the production stack. +// +// The single structural change from ba_msm_bucket_bench: +// global storage is the packed 254-bit value stored as 8x u32 +// (32 bytes/elem == 2x vec4), not the 20x 13-bit-limb BigInt +// (80 bytes/elem == 5x vec4). 
Unpack into 20x13-bit limbs only +// in-register at load and repack on store. The pack/unpack is the +// decoupled full-ILP straight-line form (injected below as +// unpack256_to_limbs / pack_limbs_to_256): 20 mutually-independent +// compile-time-constant-indexed limb extractions, zero loop-carried +// bit-cursor dependency chain. This cuts global traffic 2.5x (the +// dominant cost in the memory-bound batch-affine kernel) at a +// sub-cycle in-register cost. +// +// LAYOUT: packed elem = 2 vec4; for each of the 4 input planes +// (A.x, A.y, P.x, P.y) and 2 output planes (R.x, R.y), plane c holds +// N elements at indices c*2*N + 2*e + {0,1}. params.x = N (total +// point-adds), params.y = T (thread count = N/S). +// +// Thread t streams points e = t + i*T for i in 0..S (strided => fully +// coalesced across the apply phase). The "left" operand of add i is the +// running accumulator A_i; A_0 is the per-thread seed (plane 0/1 at +// e=t), A_{i+1} := P_i (load-carry; same global address as forward +// pass's P_i load, no extra global traffic). + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; // 8 u32 packed limbs / 4 = 2 vec4 groups + +@group(0) @binding(0) var inp: array>; +@group(0) @binding(1) var unused: array>; +@group(0) @binding(2) var outp: array>; +@group(0) @binding(3) var params: vec4; + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +fn load_be_packed(plane_base: u32, e: u32, N: u32) -> BigInt { + // plane_base is in vec4 units; per plane: 2*N vec4 (PG=2). 
+ let base = plane_base + PG * e; + let q0 = inp[base + 0u]; + let q1 = inp[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_be_packed(plane_base: u32, e: u32, N: u32, val: ptr) { + let w = pack_limbs_to_256(val); + let base = plane_base + PG * e; + outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let N = params.x; + let T = params.y; + let t = gid.x; + if (t >= T) { return; } + + // Plane bases in vec4 units. Each plane spans PG*N vec4. + let plane = PG * N; + let ax_base = 0u * plane; + let ay_base = 1u * plane; + let px_base = 2u * plane; + let py_base = 3u * plane; + + // Resident accumulator A.x stays in registers across the whole + // chunk (drives the forward dx prefix chain). A.y is only needed in + // the backward peel and is re-loaded there from the same SoA plane. + var acc_x = load_be_packed(ax_base, t, N); + + // Forward pass: running prefix-product of the S dx values + // dx_i = P_i.x - A_i.x. A_i is the prefix accumulator (resident). + var pref: array; + var acc: BigInt = get_r(); + for (var i = 0u; i < S; i = i + 1u) { + let e = t + i * T; + var p_x = load_be_packed(px_base, e, N); + var dx = fr_sub(&p_x, &acc_x); + if (i == 0u) { + acc = dx; + } else { + acc = montgomery_product(&acc, &dx); + } + pref[i] = acc; + // Resident accumulator advances along the streamed chain: + // A_0 is the seed, A_{i+1} := P_i. Points are independent + // (P_i.x != A_i.x) so every dx is a well-defined nonzero + // difference. inv_dx is deferred to the backward pass (ONE + // fr_inv_by_a per chunk of S); A stays in registers throughout. 
+ acc_x = p_x; + } + + var inv: BigInt = fr_inv_by_a(acc); + + // Backward peel + lean affine formula (dx recomputed free). + for (var jj = 0u; jj < S; jj = jj + 1u) { + let i = S - 1u - jj; + let e = t + i * T; + var p_x = load_be_packed(px_base, e, N); + var p_y = load_be_packed(py_base, e, N); + + // A_i (left operand): A_0 is the seed, A_i = P_{i-1} for i>0 + // (matches the forward acc_x recurrence; points independent so + // dx = P_i.x - A_i.x is always well-defined and nonzero). + var a_x: BigInt; + var a_y: BigInt; + if (i == 0u) { + a_x = load_be_packed(ax_base, t, N); + a_y = load_be_packed(ay_base, t, N); + } else { + let ep = t + (i - 1u) * T; + a_x = load_be_packed(px_base, ep, N); + a_y = load_be_packed(py_base, ep, N); + } + + var inv_dx: BigInt; + if (i == 0u) { + inv_dx = inv; + } else { + var pp = pref[i - 1u]; + inv_dx = montgomery_product(&inv, &pp); + } + + var lambda = fr_sub(&p_y, &a_y); + lambda = montgomery_product(&lambda, &inv_dx); + var r_x = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &a_x); + r_x = fr_sub(&r_x, &p_x); + var r_y = fr_sub(&a_x, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &a_y); + + store_be_packed(0u * plane, e, N, &r_x); + store_be_packed(1u * plane, e, N, &r_y); + + if (i != 0u) { + var dx_back = fr_sub(&p_x, &a_x); + inv = montgomery_product(&inv, &dx_back); + } + } +} +`; + +export const ba_scatter_pairs_bench = `{{> structs }} + +// Scatter kernel for the bin-packed pair-tree MSM bucket-accumulate. +// +// For each (chunk t, slot k), reads R.x/R.y from the disjoint kernel's +// strided output (where it landed at flat index t + k * T after +// running with final_flag=1) and writes them to active_sums_new at +// the destination index given by scatter_plan[t * S + k]. +// +// This is the per-bucket-placement pass that re-groups pair sums for +// the next level's bin-packing planner. +// +// scatter_plan layout: 1 u32 per (chunk, slot). 
+// scatter_plan[t * S + k] = dst_idx (active_sums_new index) +// +// disjoint_out layout: 2 planes (R.x, R.y), PG=2 vec4 per element, +// S * T elements per plane (matches the disjoint kernel's +// final-mode simple strided write). +// +// active_sums_new layout: 2 planes (P.x, P.y), PG=2 vec4 per element, +// M_new elements per plane (params.y). + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var scatter_plan: array; +@group(0) @binding(1) var disjoint_out: array>; +@group(0) @binding(2) var active_sums_new: array>; +@group(0) @binding(3) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let M_new = params.y; + let t = gid.x; + if (t >= T) { return; } + + let out_N = S * T; + let out_plane_x = 0u * PG * out_N; + let out_plane_y = 1u * PG * out_N; + + let new_plane_x = 0u * PG * M_new; + let new_plane_y = 1u * PG * M_new; + + for (var k: u32 = 0u; k < S; k = k + 1u) { + let e = t + k * T; + let dst_idx = scatter_plan[t * S + k]; + + let src_x = out_plane_x + PG * e; + let src_y = out_plane_y + PG * e; + let dst_x = new_plane_x + PG * dst_idx; + let dst_y = new_plane_y + PG * dst_idx; + + active_sums_new[dst_x + 0u] = disjoint_out[src_x + 0u]; + active_sums_new[dst_x + 1u] = disjoint_out[src_x + 1u]; + active_sums_new[dst_y + 0u] = disjoint_out[src_y + 0u]; + active_sums_new[dst_y + 1u] = disjoint_out[src_y + 1u]; + } + + {{{ recompile }}} +} +`; + +export const ba_scatter_pairs_prod = `{{> structs }} + +// Scatter kernel — prod variant for the v2 pair-tree integration. +// Same per-bucket placement math as ba_scatter_pairs_bench; T is read +// from the planner's totals[3] and the dispatch is indirect via +// totals[4..6]. 
+ +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var scatter_plan: array; +@group(0) @binding(1) var disjoint_out: array>; +@group(0) @binding(2) var active_sums_new: array>; +@group(0) @binding(3) var totals: array; +@group(0) @binding(4) var consts: vec4; +// consts.x = M_new + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = totals[3]; + let M_new = consts.x; + let t = gid.x; + if (t >= T) { return; } + + let out_N = S * T; + let out_plane_x = 0u * PG * out_N; + let out_plane_y = 1u * PG * out_N; + + let new_plane_x = 0u * PG * M_new; + let new_plane_y = 1u * PG * M_new; + + for (var k: u32 = 0u; k < S; k = k + 1u) { + let e = t + k * T; + let dst_idx = scatter_plan[t * S + k]; + + let src_x = out_plane_x + PG * e; + let src_y = out_plane_y + PG * e; + let dst_x = new_plane_x + PG * dst_idx; + let dst_y = new_plane_y + PG * dst_idx; + + active_sums_new[dst_x + 0u] = disjoint_out[src_x + 0u]; + active_sums_new[dst_x + 1u] = disjoint_out[src_x + 1u]; + active_sums_new[dst_y + 0u] = disjoint_out[src_y + 0u]; + active_sums_new[dst_y + 1u] = disjoint_out[src_y + 1u]; + } + + {{{ recompile }}} +} +`; + +export const ba_tail_reduce_bench = `{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +// Tail kernel for the bench-msm-tree pipeline: reduces a single +// tail-sized bucket (count < 2*S) to one sum per thread. Each thread +// reads its bucket's count points sequentially from the SoA-packed +// point pool and accumulates them via direct affine adds (one +// fr_inv_by_a per step). +// +// Pragmatic v1 — no batched inversion across threads. Each step pays +// one full fr_inv_by_a (~80 mont mul equivalents). 
For typical +// Poisson(lambda=16) MSM workloads, tail buckets carry a minority of +// total work (~10-30%); the contribution to overall bucket-accumulate +// ns/in-pt is small enough that this simple design is acceptable for +// a v1 complete-replacement kernel set. A workgroup-scan +// batched-inversion variant is a follow-on optimisation that would +// drop tail cost to ~25 ns/add (matching the main pair-tree). +// +// Bindings: +// binding 0: csr_indices — sorted point indices, 1-based (index 0 reserved). +// binding 1: tail_plan — three u32 per tail thread: +// [bucket_id, csr_start, count]. +// binding 2: point_pool — SoA-packed pool (2 planes, PG=2 vec4/elem). +// binding 3: bucket_sums — SoA-packed output (2 planes, PG=2 vec4/bucket), +// one packed point per bucket. Pre-zeroed by host. +// binding 4: params — params.x=T (tail thread count), +// params.y=N (pool size), +// params.z=B (bucket_sums slot count). +// +// Bounded loop: the per-thread accumulate loop iterates up to compile- +// time TAIL_CAP = 2*S - 1, breaking early when i >= count. No +// data-dependent unbounded loops. 
+ +const TAIL_CAP: u32 = {{ tail_cap }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var csr_indices: array; +@group(0) @binding(1) var tail_plan: array; +@group(0) @binding(2) var point_pool: array>; +@group(0) @binding(3) var bucket_sums: array>; +@group(0) @binding(4) var params: vec4; + +fn load_pool(plane: u32, idx: u32, N: u32) -> BigInt { + let plane_base = plane * PG * N; + let base = plane_base + PG * idx; + let q0 = point_pool[base + 0u]; + let q1 = point_pool[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_bucket(plane: u32, b: u32, B: u32, val: ptr) { + let plane_base = plane * PG * B; + let base = plane_base + PG * b; + let w = pack_limbs_to_256(val); + bucket_sums[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + bucket_sums[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let N = params.y; + let B = params.z; + + let t = gid.x; + if (t >= T) { return; } + + let bucket_id = tail_plan[3u * t + 0u]; + let csr_start = tail_plan[3u * t + 1u]; + let count = tail_plan[3u * t + 2u]; + + if (count == 0u) { return; } + + var acc_x: BigInt = load_pool(0u, csr_indices[csr_start], N); + var acc_y: BigInt = load_pool(1u, csr_indices[csr_start], N); + + for (var i: u32 = 1u; i < TAIL_CAP; i = i + 1u) { + if (i >= count) { break; } + let pt_idx = csr_indices[csr_start + i]; + var p_x: BigInt = load_pool(0u, pt_idx, N); + var p_y: BigInt = load_pool(1u, pt_idx, N); + var dx: BigInt = fr_sub(&p_x, &acc_x); + var inv_dx: BigInt = fr_inv_by_a(dx); + var dy: BigInt = fr_sub(&p_y, &acc_y); + var lambda: BigInt = montgomery_product(&dy, &inv_dx); + var lambda_sq: BigInt = montgomery_product(&lambda, &lambda); + var r_x: BigInt = fr_sub(&lambda_sq, 
&acc_x); + r_x = fr_sub(&r_x, &p_x); + var r_y: BigInt = fr_sub(&acc_x, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &acc_y); + acc_x = r_x; + acc_y = r_y; + } + + store_bucket(0u, bucket_id, B, &acc_x); + store_bucket(1u, bucket_id, B, &acc_y); + + {{{ recompile }}} +} +`; + export const barrett = `const W_MASK = {{ w_mask }}u; const SLACK = {{ slack }}u; @@ -2660,6 +4524,230 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { } `; +export const batch_affine_fused_wg_scan = `{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +{{> packed_field_funcs }} + +// Workgroup-scan fused batch-affine round kernel for v2 MSM. +// +// Mirrors \`bench_batch_affine.template.wgsl\`'s phases A/B/C/D — TPB +// threads cooperating on BATCH_SIZE = TPB*BS pairs per workgroup with +// one fr_inv_by_a per workgroup — adapted for the MSM pipeline: +// - storage is packed 8×u32 per field element (vs the bench's +// BigInt-array storage); conversions happen only at field_load_* +// and field_store, every kernel-local var holds BigInt limbs +// - loads are bucket-indirect via \`pair_target_meta\` (vs the bench's +// flat \`inputs[pair_base + *]\`) +// +// PHASES +// A) Per-thread serial prefix product over BS pairs. Each thread +// writes its prefix-product chain to \`prefix[batch_base + k]\` +// (global storage) and captures \`block_total\` in a register. +// B) Workgroup-shared Hillis-Steele forward + backward scan over the +// TPB block_totals (log2 TPB rounds of mont mul). +// C) Thread 0 inverts the global product via fr_inv_by_a (ONE per +// workgroup). Broadcasts to wg_inv_total. 
+// D) Each thread back-walks its chunk, recovers inv_dx for each pair +// from (wg_inv_total * block_excl_prefix * block_excl_suffix * +// prev_in_chunk_prefix), emits lean affine add, scatters to +// running_x/y[bucket]. +// +// SAFETY +// The scheduler emits at most one pair per (subtask, bucket) per +// round. Within a workgroup's BATCH_SIZE slots, every \`bucket\` is +// distinct → no intra-workgroup RAW hazard on the running_x/y +// scatters. Across workgroups in the same subtask: disjoint slot +// ranges → still distinct buckets. Across subtasks (Z dim): different +// bucket ranges entirely. +// +// DISPATCH +// workgroup_size = TPB. Workgroups in X = ceil(n / (TPB*BS)). +// Workgroups in Z = num_subtasks. The atomicLoad of count_buf and +// subsequent control flow are uniform within a workgroup (every +// thread sees the same \`n\`), but Tint can't prove that — so we never +// early-return based on it. Instead, partial-batch threads contribute +// identity to the scan and skip their work loop bodies. 
+ +const TPB: u32 = {{ tpb }}u; +const BS: u32 = {{ bs }}u; +const BATCH_SIZE: u32 = {{ batch_size }}u; + +@group(0) @binding(0) +var val_idx: array; +@group(0) @binding(1) +var new_point_x: array>; +@group(0) @binding(2) +var new_point_y: array>; +@group(0) @binding(3) +var running_x: array>; +@group(0) @binding(4) +var running_y: array>; +@group(0) @binding(5) +var pair_target_meta: array; +@group(0) @binding(6) +var prefix_buf: array; +@group(0) @binding(7) +var count_buf: array>; + +// params[0] = num_columns (per-subtask pool stride) +// params[1] = input_size (per-subtask val_idx stride) +@group(0) @binding(8) +var params: vec4; + +var wg_fwd: array; +var wg_bwd: array; +var wg_inv_total: BigInt; + +@compute +@workgroup_size({{ tpb }}) +fn main( + @builtin(local_invocation_id) lid: vec3, + @builtin(workgroup_id) wid: vec3, +) { + let tid = lid.x; + let wg_idx = wid.x; + let subtask_idx = wid.z; + let num_columns = params[0]; + let input_size = params[1]; + + let n = atomicLoad(&count_buf[subtask_idx]); + let batch_base = wg_idx * BATCH_SIZE; + + let pool_base = subtask_idx * num_columns; + let vi_offset = subtask_idx * input_size; + + let chunk_start = tid * BS; + let chunk_pool_base = pool_base + batch_base + chunk_start; + + let in_pool = batch_base + chunk_start + BS <= n; + + // Phase A — per-thread serial prefix product. Threads with + // in_pool == false (chunk not fully inside the live pool) + // contribute identity (R = Mont 1) so + // the workgroup scan reads a sane value at every slot. 
+ var block_total: BigInt = get_r(); + if (in_pool) { + { + let k0 = 0u; + let slot = chunk_pool_base + k0; + let bucket = pair_target_meta[2u * slot]; + let q_cursor = pair_target_meta[2u * slot + 1u]; + let pt_idx = val_idx[vi_offset + q_cursor]; + var p_x: BigInt = field_load_rw(bucket, &running_x); + var q_x: BigInt = field_load_ro(pt_idx, &new_point_x); + var dx: BigInt = fr_sub(&q_x, &p_x); + prefix_buf[chunk_pool_base + k0] = dx; + block_total = dx; + } + for (var i: u32 = 1u; i < BS; i = i + 1u) { + let slot = chunk_pool_base + i; + let bucket = pair_target_meta[2u * slot]; + let q_cursor = pair_target_meta[2u * slot + 1u]; + let pt_idx = val_idx[vi_offset + q_cursor]; + var p_x: BigInt = field_load_rw(bucket, &running_x); + var q_x: BigInt = field_load_ro(pt_idx, &new_point_x); + var dx: BigInt = fr_sub(&q_x, &p_x); + block_total = montgomery_product(&block_total, &dx); + prefix_buf[chunk_pool_base + i] = block_total; + } + } + + wg_fwd[tid] = block_total; + wg_bwd[tid] = block_total; + workgroupBarrier(); + + // Phase B — Hillis-Steele forward + backward inclusive scan. + for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) { + var fwd_x: BigInt = wg_fwd[tid]; + if (tid >= stride) { + var lhs: BigInt = wg_fwd[tid - stride]; + fwd_x = montgomery_product(&lhs, &fwd_x); + } + var bwd_x: BigInt = wg_bwd[tid]; + if (tid + stride < TPB) { + var rhs: BigInt = wg_bwd[tid + stride]; + bwd_x = montgomery_product(&bwd_x, &rhs); + } + workgroupBarrier(); + wg_fwd[tid] = fwd_x; + wg_bwd[tid] = bwd_x; + workgroupBarrier(); + } + + // Phase C — single fr_inv per workgroup. + if (tid == 0u) { + var global_total: BigInt = wg_fwd[TPB - 1u]; + wg_inv_total = fr_inv_by_a(global_total); + } + workgroupBarrier(); + + // Phase D — back-walk this thread's chunk, emit lean affine adds. 
+ if (!in_pool) { + return; + } + var block_excl_prefix: BigInt = get_r(); + if (tid > 0u) { + block_excl_prefix = wg_fwd[tid - 1u]; + } + var block_excl_suffix: BigInt = get_r(); + if (tid + 1u < TPB) { + block_excl_suffix = wg_bwd[tid + 1u]; + } + var inv_global: BigInt = wg_inv_total; + var inv_acc: BigInt = montgomery_product(&inv_global, &block_excl_prefix); + inv_acc = montgomery_product(&inv_acc, &block_excl_suffix); + + for (var off: u32 = 0u; off < BS; off = off + 1u) { + let k = BS - 1u - off; + let slot = chunk_pool_base + k; + let bucket = pair_target_meta[2u * slot]; + let q_cursor = pair_target_meta[2u * slot + 1u]; + let pt_idx = val_idx[vi_offset + q_cursor]; + + var p_x: BigInt = field_load_rw(bucket, &running_x); + var p_y: BigInt = field_load_rw(bucket, &running_y); + var q_x: BigInt = field_load_ro(pt_idx, &new_point_x); + var q_y: BigInt = field_load_ro(pt_idx, &new_point_y); + + var inv_dx: BigInt; + if (k > 0u) { + var prev_prefix: BigInt = prefix_buf[chunk_pool_base + (k - 1u)]; + inv_dx = montgomery_product(&inv_acc, &prev_prefix); + } else { + inv_dx = inv_acc; + } + + var dy: BigInt = fr_sub(&q_y, &p_y); + var lambda: BigInt = montgomery_product(&dy, &inv_dx); + var lambda_sq: BigInt = montgomery_product(&lambda, &lambda); + var t1: BigInt = fr_sub(&lambda_sq, &p_x); + var r_x: BigInt = fr_sub(&t1, &q_x); + var dx_back: BigInt = fr_sub(&p_x, &r_x); + var ldx: BigInt = montgomery_product(&lambda, &dx_back); + var r_y: BigInt = fr_sub(&ldx, &p_y); + + field_store(bucket, &running_x, &r_x); + field_store(bucket, &running_y, &r_y); + + if (k > 0u) { + var dx_k: BigInt = fr_sub(&q_x, &p_x); + inv_acc = montgomery_product(&inv_acc, &dx_k); + } + } + + {{{ recompile }}} +} +`; + export const batch_affine_init = `{{> structs }} // Init kernel for the batch-affine SMVP pipeline. 
@@ -4667,6 +6755,126 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { } `; +export const csr_to_v2_active_sums = `// Layout converter for the v2 pair-tree MSM bucket-accumulate path. +// +// Materializes the bucket-major active_sums buffer by copying packed +// 8×u32 base coords from the cached_bases (new_point_x / new_point_y) +// at the indices listed in val_idx (cuZK transpose output, bucket-major +// per subtask). +// +// active_sums is one combined-SoA storage buffer (matching what the v2 +// pair-tree kernels marshal_pairs / pair_disjoint_tree / scatter_pairs +// / carry_copy consume): +// plane 0 (x) at vec4 indices [0, PG * M) +// plane 1 (y) at vec4 indices [PG * M, 2 * PG * M) +// per-element layout: PG=2 vec4 at [PG*elem, PG*elem+1]. +// M (elements per plane) is passed via params.y so this shader uses a +// single storage binding instead of two subviews of the same buffer — +// the subview path tripped a silent dispatch no-op on M2 Chrome 148 +// because plane-y's byte offset (PG*M*16 = 8256 for M=258) is not a +// multiple of WebGPU's default minStorageBufferOffsetAlignment of 256. +// +// Per (subtask s, slot k) thread with slot = s * input_size + k: +// pt_idx = val_idx[slot] +// active_sums[PG * slot + v] = new_point_x[PG * pt_idx + v] +// active_sums[PG * M + PG * slot + v] = new_point_y[PG * pt_idx + v] +// for v in {0, 1}. +// +// The copy is a raw element copy — destination element bytes equal +// source element bytes; no unpack / pack needed. Sign handling stays at +// finalize (cuZK encodes signed slices via bucket index, not via point +// negation). 
+ +const PG: u32 = 2u; + +@group(0) @binding(0) +var val_idx: array; +@group(0) @binding(1) +var new_point_x: array>; +@group(0) @binding(2) +var new_point_y: array>; +@group(0) @binding(3) +var active_sums: array>; + +// params.x = total_slots (num_subtasks * input_size, OR per-window +// input_size when the caller binds val_idx as a per-window subview) +// params.y = M (elements per plane in active_sums) +@group(0) @binding(4) +var params: vec4; + +@compute +@workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let slot = gid.x; + let total = params[0]; + if (slot >= total) { + return; + } + + let M = params[1]; + let pt_idx = val_idx[slot]; + + let plane_x_base = PG * slot; + let plane_y_base = PG * M + PG * slot; + let src_x = PG * pt_idx; + let src_y = PG * pt_idx; + + active_sums[plane_x_base + 0u] = new_point_x[src_x + 0u]; + active_sums[plane_x_base + 1u] = new_point_x[src_x + 1u]; + active_sums[plane_y_base + 0u] = new_point_y[src_y + 0u]; + active_sums[plane_y_base + 1u] = new_point_y[src_y + 1u]; + + {{{ recompile }}} +} +`; + +export const csr_to_v2_meta = `// Companion to csr_to_v2_active_sums: derives the per-bucket counts and +// subtask-relative offsets that drive the v2 pair-tree planner. +// +// row_ptr layout: per subtask, num_columns + 1 entries forming a +// CSR-style prefix sum. row_ptr[s * (num_columns + 1) + b + 1] - +// row_ptr[s * (num_columns + 1) + b] is the count of points in bucket +// b of subtask s, and the begin value is the subtask-relative start +// offset within val_idx and active_sums. +// +// One thread per (subtask, bucket) emits one (count, offset) pair. 
+ +@group(0) @binding(0) +var row_ptr: array; +@group(0) @binding(1) +var active_counts: array; +@group(0) @binding(2) +var active_offsets: array; + +// params[0] = num_columns +// params[1] = total_buckets (num_subtasks * num_columns) +@group(0) @binding(3) +var params: vec4; + +@compute +@workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let id = gid.x; + let total = params[1]; + if (id >= total) { + return; + } + + let num_columns = params[0]; + let subtask = id / num_columns; + let bucket_local = id % num_columns; + let rp_offset = subtask * (num_columns + 1u); + + let begin = row_ptr[rp_offset + bucket_local]; + let end = row_ptr[rp_offset + bucket_local + 1u]; + + active_counts[id] = end - begin; + active_offsets[id] = begin; + + {{{ recompile }}} +} +`; + export const decompose_scalars_signed_only = `// Scalars-only variant of \`convert_point_coords_and_decompose_scalars\`. // Reads 32-byte LE scalars from a packed u32 buffer and writes one // shifted-signed bucket index per scalar per subtask into \`chunks\`. @@ -6908,6 +9116,70 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { {{{ recompile }}} }`; +export const v2_to_running = `// Boundary adapter from the v2 bin-packed pair-tree's per-window +// active_sums buffer (combined SoA, plane 0 = X / plane 1 = Y at vec4 +// indices [PG*elem + v]) to the production running_x / running_y / +// bucket_active layout that batch_affine_finalize_collect consumes. +// +// Per-window dispatch: one thread per (subtask, bucket_local). The +// caller binds the per-window active_sums (combined SoA), the final +// counts and offsets emitted by the planner's last level, and views of +// the global running_x / running_y / bucket_active arrays offset by +// subtask_idx * num_columns so a single bucket_global is addressable +// via gid.x. 
+// +// For non-empty buckets the v2 pair-tree has reduced the bucket to one +// packed-Montgomery point sitting at active_sums[final_offsets[b]] in +// the input plane layout. We copy that element into running_x / +// running_y at the matching bucket_global slot (packed 8x u32 = two +// vec4 per element, same layout production already uses when packed). +// Empty buckets only set bucket_active = 0 — running_x / running_y are +// left untouched; finalize is gated on bucket_active and never reads +// the unwritten slot. + +const PG: u32 = 2u; + +@group(0) @binding(0) var active_sums: array>; +@group(0) @binding(1) var final_counts: array; +@group(0) @binding(2) var final_offsets: array; +@group(0) @binding(3) var running_x: array>; +@group(0) @binding(4) var running_y: array>; +@group(0) @binding(5) var bucket_active: array; +@group(0) @binding(6) var params: vec4; +// params.x = num_columns (active per-window bucket count) +// params.y = M (elements per plane in the v2 active_sums buffer) + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let bucket_local = gid.x; + let num_columns = params.x; + let M = params.y; + if (bucket_local >= num_columns) { + return; + } + + let count = final_counts[bucket_local]; + if (count == 0u) { + bucket_active[bucket_local] = 0u; + return; + } + + bucket_active[bucket_local] = 1u; + + let slot = final_offsets[bucket_local]; + let plane_x_base = PG * slot; + let plane_y_base = PG * M + PG * slot; + let dst = PG * bucket_local; + + running_x[dst + 0u] = active_sums[plane_x_base + 0u]; + running_x[dst + 1u] = active_sums[plane_x_base + 1u]; + running_y[dst + 0u] = active_sums[plane_y_base + 0u]; + running_y[dst + 1u] = active_sums[plane_y_base + 1u]; + + {{{ recompile }}} +} +`; + export const by_inverse = `// Bernstein-Yang safegcd inversion for the BN254 base field, WGSL port. 
// // This file will grow over sub-steps 1.3-1.5 of the WebGPU MSM rewrite plan @@ -9644,6 +11916,56 @@ fn mulhilo2(a: vec2, b: vec2) -> vec4 { } `; +export const packed_field = `// Packed 256-bit field-element storage helpers for v2 MSM. +// +// Storage convention: every field-element buffer is \`array>\` +// with logical stride 2 vec4s per element (8 × u32 = 32 bytes, +// canonical little-endian 256-bit value, value < q < 2^254). +// +// Conversions between the packed storage layout and the 20×13-bit +// \`BigInt\` arithmetic representation happen ONLY at the storage I/O +// boundary (field_load_*, field_store, fold_packed_pair). Once loaded, +// values live as BigInt limbs for the entire kernel body and only +// repack on the final write. This matches the bench_batch_affine design +// that hit ~22 ns/pair on M2; the prior PackedField-wrapper design +// repacked between every mont and paid ~2× the cost. +// +// PRECONDITION: this partial must be included after bigint_funcs, +// montgomery_product_funcs, field_funcs, by_inverse_a_funcs, and after +// the host has injected unpack256_to_limbs and pack_limbs_to_256 (those +// come from the decoupledPackUnpackWgsl() generator in shader_manager). 
+ +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +fn field_load_ro(idx: u32, src: ptr>, read>) -> BigInt { + var w: array; + let q0 = (*src)[2u * idx]; + let q1 = (*src)[2u * idx + 1u]; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn field_load_rw(idx: u32, src: ptr>, read_write>) -> BigInt { + var w: array; + let q0 = (*src)[2u * idx]; + let q1 = (*src)[2u * idx + 1u]; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn field_store(idx: u32, dst: ptr>, read_write>, val: ptr) { + let w = pack_limbs_to_256(val); + (*dst)[2u * idx] = vec4(w[0], w[1], w[2], w[3]); + (*dst)[2u * idx + 1u] = vec4(w[4], w[5], w[6], w[7]); +} +`; + export const structs = `struct Point { x: BigInt, y: BigInt, diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_carry_copy_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_carry_copy_bench.template.wgsl new file mode 100644 index 000000000000..50f409778d97 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_carry_copy_bench.template.wgsl @@ -0,0 +1,54 @@ +{{> structs }} + +// Carry-copy kernel for the bin-packed pair-tree MSM bucket-accumulate. +// +// For each carry slot t, copies one packed (x, y) point from +// active_sums_old[carry_plan[2*t + 0]] to +// active_sums_new[carry_plan[2*t + 1]]. +// +// Used when a bucket has an odd active count at the current level: +// floor(N_b / 2) elements get paired and produce floor(N_b / 2) sums +// in the next level, plus the (N_b mod 2 == 1) carry element propagates +// forward unchanged. +// +// Pure memory shuffle, no field arithmetic. 
+// +// params.x = T (number of carry-copies / threads) +// params.y = M_old (active_sums_old size, vec4-stride scaling) +// params.z = M_new (active_sums_new size, vec4-stride scaling) + +const PG: u32 = 2u; + +@group(0) @binding(0) var carry_plan: array; +@group(0) @binding(1) var active_sums_old: array>; +@group(0) @binding(2) var active_sums_new: array>; +@group(0) @binding(3) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let M_old = params.y; + let M_new = params.z; + let t = gid.x; + if (t >= T) { return; } + + let src_idx = carry_plan[2u * t + 0u]; + let dst_idx = carry_plan[2u * t + 1u]; + + let old_plane_x = 0u * PG * M_old; + let old_plane_y = 1u * PG * M_old; + let new_plane_x = 0u * PG * M_new; + let new_plane_y = 1u * PG * M_new; + + let src_x = old_plane_x + PG * src_idx; + let src_y = old_plane_y + PG * src_idx; + let dst_x = new_plane_x + PG * dst_idx; + let dst_y = new_plane_y + PG * dst_idx; + + active_sums_new[dst_x + 0u] = active_sums_old[src_x + 0u]; + active_sums_new[dst_x + 1u] = active_sums_old[src_x + 1u]; + active_sums_new[dst_y + 0u] = active_sums_old[src_y + 0u]; + active_sums_new[dst_y + 1u] = active_sums_old[src_y + 1u]; + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_carry_copy_prod.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_carry_copy_prod.template.wgsl new file mode 100644 index 000000000000..c3b1b12787ed --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_carry_copy_prod.template.wgsl @@ -0,0 +1,44 @@ +{{> structs }} + +// Carry-copy kernel — prod variant for the v2 pair-tree integration. +// num_carries is read from the planner's totals[1] and dispatch is +// indirect via totals[7..9]. 
+ +const PG: u32 = 2u; + +@group(0) @binding(0) var carry_plan: array; +@group(0) @binding(1) var active_sums_old: array>; +@group(0) @binding(2) var active_sums_new: array>; +@group(0) @binding(3) var totals: array; +@group(0) @binding(4) var consts: vec4; +// consts.x = M_old +// consts.y = M_new + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = totals[1]; + let M_old = consts.x; + let M_new = consts.y; + let t = gid.x; + if (t >= T) { return; } + + let src_idx = carry_plan[2u * t + 0u]; + let dst_idx = carry_plan[2u * t + 1u]; + + let old_plane_x = 0u * PG * M_old; + let old_plane_y = 1u * PG * M_old; + let new_plane_x = 0u * PG * M_new; + let new_plane_y = 1u * PG * M_new; + + let src_x = old_plane_x + PG * src_idx; + let src_y = old_plane_y + PG * src_idx; + let dst_x = new_plane_x + PG * dst_idx; + let dst_y = new_plane_y + PG * dst_idx; + + active_sums_new[dst_x + 0u] = active_sums_old[src_x + 0u]; + active_sums_new[dst_x + 1u] = active_sums_old[src_x + 1u]; + active_sums_new[dst_y + 0u] = active_sums_old[src_y + 0u]; + active_sums_new[dst_y + 1u] = active_sums_old[src_y + 1u]; + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_fused_super_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_fused_super_bench.template.wgsl new file mode 100644 index 000000000000..6bb5e6d964d7 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_fused_super_bench.template.wgsl @@ -0,0 +1,157 @@ +{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +// Fused super-kernel for the bin-packed pair-tree MSM bucket-accumulate. +// +// Combines marshal + disjoint + scatter into one kernel. Each thread t +// handles one chunk of S pairs: +// 1. Read 2*S source indices from chunk_plan (idx_l, idx_r per slot). +// 2. 
Read S destination indices from scatter_plan. +// 3. Load the x coordinate of both operands of each of the S pairs +// from active_sums_old, compute S dx values and their running +// prefix products, kept in a per-thread array (pref). +// 4. Single fr_inv_by_a on the prefix product. +// 5. Backward peel: per slot k from S-1 down to 0: +// - load .x and .y for both operands +// - lean affine add -> R_x, R_y +// - write directly to active_sums_new at scatter_plan[t*S + k] +// - update inv for next (smaller-k) iteration +// +// vs v2 (4 kernels: marshal, disjoint, scatter, carry): the chain_buf +// and tempOut scratch buffers are eliminated. All intermediate state +// lives in function-scope (per-thread) variables. Per-level dispatch +// count drops from 4 to 2 (fused + carry). +// +// PARAMS: +// params.x = T_chunks (active threads, one per chunk) +// params.y = M_old (elements per plane of active_sums_old) +// params.z = M_new (elements per plane of active_sums_new) +// +// Layout (both active_sums buffers): 2 planes (P.x, P.y), PG=2 vec4 per +// element. plane_p flat vec4 base = p * PG * M, element e at offset +// PG * e. 
+ +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var chunk_plan: array; +@group(0) @binding(1) var scatter_plan: array; +@group(0) @binding(2) var active_sums_old: array>; +@group(0) @binding(3) var active_sums_new: array>; +@group(0) @binding(4) var params: vec4; + +fn load_active_x(idx: u32, M: u32) -> BigInt { + let plane_base = 0u * PG * M; + let base = plane_base + PG * idx; + let q0 = active_sums_old[base + 0u]; + let q1 = active_sums_old[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn load_active_y(idx: u32, M: u32) -> BigInt { + let plane_base = 1u * PG * M; + let base = plane_base + PG * idx; + let q0 = active_sums_old[base + 0u]; + let q1 = active_sums_old[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_active_new(plane: u32, idx: u32, M: u32, val: ptr) { + let plane_base = plane * PG * M; + let base = plane_base + PG * idx; + let w = pack_limbs_to_256(val); + active_sums_new[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + active_sums_new[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let M_old = params.y; + let M_new = params.z; + let t = gid.x; + if (t >= T) { return; } + + let chunk_base = 2u * S * t; + + // Forward: compute S dx values and accumulate prefix product. + // Read pair indices from chunk_plan, load .x for each operand, compute dx. 
+ var pref: array; + var acc: BigInt = get_r(); + for (var k: u32 = 0u; k < S; k = k + 1u) { + let idx_l = chunk_plan[chunk_base + 2u * k + 0u]; + let idx_r = chunk_plan[chunk_base + 2u * k + 1u]; + var p_lx: BigInt = load_active_x(idx_l, M_old); + var p_rx: BigInt = load_active_x(idx_r, M_old); + var dx: BigInt = fr_sub(&p_rx, &p_lx); + if (k == 0u) { + acc = dx; + } else { + acc = montgomery_product(&acc, &dx); + } + pref[k] = acc; + } + + // Single inversion per chunk. + var inv: BigInt = fr_inv_by_a(acc); + + // Backward peel: emit S pair sums, scatter to active_sums_new. + for (var jj: u32 = 0u; jj < S; jj = jj + 1u) { + let k = S - 1u - jj; + let idx_l = chunk_plan[chunk_base + 2u * k + 0u]; + let idx_r = chunk_plan[chunk_base + 2u * k + 1u]; + + var p_lx: BigInt = load_active_x(idx_l, M_old); + var p_ly: BigInt = load_active_y(idx_l, M_old); + var p_rx: BigInt = load_active_x(idx_r, M_old); + var p_ry: BigInt = load_active_y(idx_r, M_old); + + var inv_dx: BigInt; + if (k == 0u) { + inv_dx = inv; + } else { + var pp = pref[k - 1u]; + inv_dx = montgomery_product(&inv, &pp); + } + + var lambda: BigInt = fr_sub(&p_ry, &p_ly); + lambda = montgomery_product(&lambda, &inv_dx); + var r_x: BigInt = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &p_lx); + r_x = fr_sub(&r_x, &p_rx); + var r_y: BigInt = fr_sub(&p_lx, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &p_ly); + + let dst_idx = scatter_plan[t * S + k]; + store_active_new(0u, dst_idx, M_new, &r_x); + store_active_new(1u, dst_idx, M_new, &r_y); + + if (k > 0u) { + var dx_back: BigInt = fr_sub(&p_rx, &p_lx); + inv = montgomery_product(&inv, &dx_back); + } + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_chain_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_chain_bench.template.wgsl new file mode 100644 index 000000000000..538be64e43a7 --- /dev/null +++ 
b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_chain_bench.template.wgsl @@ -0,0 +1,91 @@ +{{> structs }} + +// Marshal kernel for the bench-msm-chain pipeline. Transposes a CSR +// point list (sorted by bucket) into the strided SoA layout the +// ba_rev_packed_carry_bench chain kernel consumes. +// +// Input layout (point_pool): +// 2 planes (P.x, P.y), each PG=2 vec4 per element, params.y elements total. +// Plane p at point idx i: vec4 indices p*PG*N + PG*i + {0,1}. +// Convention: point_pool[0] is the "decoy" — used as the seed for every +// chunk so the chain kernel's first dx (= P_0.x - seed.x) is well- +// defined. csr_indices values are in [1, N), never 0. +// +// Output layout (chain_buf): +// 4 planes (A.x, A.y, P.x, P.y), each PG=2 vec4 per element, T*S +// elements per plane. Plane p at strided element e = t + i*T: vec4 +// indices p*PG*(T*S) + PG*e + {0,1}. +// +// Per chunk-thread t: +// - csr_start = chunk_plan[2*t + 1] (chunk_plan[2*t] = bucket_id, unused here) +// - Seed at index t (planes 0,1) := point_pool[0] (universal decoy) +// - For i in 0..S: P_i at index e = t + i*T (planes 2,3) +// := point_pool[csr_indices[csr_start + i]] +// +// The chain kernel then produces S pair-sums per chunk. The S/2 odd- +// indexed outputs (R_1, R_3, ..., R_{S-1}) are disjoint pair sums of +// {P_0..P_{S-1}}; the even outputs (R_0, R_2, ...) incorporate the +// decoy or share a P with the next odd output and are discarded by the +// subsequent reduce pass. +// +// Pure memory-shuffle kernel: no field arithmetic. Reads are coalesced +// because consecutive threads t, t+1 read adjacent csr_indices entries +// and the gathered point coords are written to adjacent vec4 slots +// (PG*e for e=t, t+1, ...). 
+ +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var csr_indices: array; +@group(0) @binding(1) var chunk_plan: array; +@group(0) @binding(2) var point_pool: array>; +@group(0) @binding(3) var chain_buf: array>; +@group(0) @binding(4) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let N = params.y; + let t = gid.x; + if (t >= T) { return; } + + let csr_start = chunk_plan[2u * t + 1u]; + + let chain_N = T * S; + let chain_plane = PG * chain_N; + let chain_ax_base = 0u * chain_plane; + let chain_ay_base = 1u * chain_plane; + let chain_px_base = 2u * chain_plane; + let chain_py_base = 3u * chain_plane; + + let pool_plane = PG * N; + let pool_px_base = 0u * pool_plane; + let pool_py_base = 1u * pool_plane; + + // Seed (A.x, A.y at index t) := point_pool[0] (decoy). + let decoy_x_off = pool_px_base + PG * 0u; + let decoy_y_off = pool_py_base + PG * 0u; + let seed_x_off = chain_ax_base + PG * t; + let seed_y_off = chain_ay_base + PG * t; + chain_buf[seed_x_off + 0u] = point_pool[decoy_x_off + 0u]; + chain_buf[seed_x_off + 1u] = point_pool[decoy_x_off + 1u]; + chain_buf[seed_y_off + 0u] = point_pool[decoy_y_off + 0u]; + chain_buf[seed_y_off + 1u] = point_pool[decoy_y_off + 1u]; + + // Gather S points from csr_indices[csr_start..csr_start+S] into the + // strided P-planes at indices e = t + i*T for i in 0..S. 
+ for (var i = 0u; i < S; i = i + 1u) { + let pt_idx = csr_indices[csr_start + i]; + let e = t + i * T; + let pool_x_off = pool_px_base + PG * pt_idx; + let pool_y_off = pool_py_base + PG * pt_idx; + let chain_px_off = chain_px_base + PG * e; + let chain_py_off = chain_py_base + PG * e; + chain_buf[chain_px_off + 0u] = point_pool[pool_x_off + 0u]; + chain_buf[chain_px_off + 1u] = point_pool[pool_x_off + 1u]; + chain_buf[chain_py_off + 0u] = point_pool[pool_y_off + 0u]; + chain_buf[chain_py_off + 1u] = point_pool[pool_y_off + 1u]; + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_pairs_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_pairs_bench.template.wgsl new file mode 100644 index 000000000000..a83210bc4ade --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_pairs_bench.template.wgsl @@ -0,0 +1,79 @@ +{{> structs }} + +// Marshal kernel for the bin-packed pair-tree MSM bucket-accumulate. +// +// Reads (idx_l, idx_r) operand indices per pair from chunk_plan, +// fetches the corresponding packed 8x u32 points from an active_sums +// buffer (2-plane SoA), and writes them into the disjoint kernel's +// strided input layout. +// +// Used both at level 0 (active_sums = bucket-sorted point pool) and +// at levels 1+ (active_sums = previous level's pair-sum + carry +// outputs). The kernel is bucket-agnostic; the planner has packed +// each chunk's S pairs from whatever buckets fit, and chunk_plan +// encodes the operand source indices. +// +// chunk_plan layout: 2 * S u32 per chunk +// chunk_plan[2 * (t * S + k) + 0] = idx_left (active_sums index) +// chunk_plan[2 * (t * S + k) + 1] = idx_right (active_sums index) +// +// active_sums layout: 2 planes (P.x, P.y), PG=2 vec4 per element, +// M_in elements per plane (params.y). +// +// chain_buf layout: 2 planes (P.x, P.y), PG=2 vec4 per element, +// 2 * S * T elements per plane. 
Slot (t, 2k+0) holds left, slot +// (t, 2k+1) holds right at the disjoint kernel's strided positions +// e = t + i * T for i = 2k, 2k+1. + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var chunk_plan: array; +@group(0) @binding(1) var active_sums: array>; +@group(0) @binding(2) var chain_buf: array>; +@group(0) @binding(3) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let M_in = params.y; + let t = gid.x; + if (t >= T) { return; } + + let chain_N = 2u * S * T; + let chain_plane_x = 0u * PG * chain_N; + let chain_plane_y = 1u * PG * chain_N; + + let active_plane_x = 0u * PG * M_in; + let active_plane_y = 1u * PG * M_in; + + let chunk_base = 2u * S * t; + for (var k: u32 = 0u; k < S; k = k + 1u) { + let idx_l = chunk_plan[chunk_base + 2u * k + 0u]; + let idx_r = chunk_plan[chunk_base + 2u * k + 1u]; + + let e_l = t + (2u * k + 0u) * T; + let e_r = t + (2u * k + 1u) * T; + + let src_lx = active_plane_x + PG * idx_l; + let src_ly = active_plane_y + PG * idx_l; + let src_rx = active_plane_x + PG * idx_r; + let src_ry = active_plane_y + PG * idx_r; + + let dst_lx = chain_plane_x + PG * e_l; + let dst_ly = chain_plane_y + PG * e_l; + let dst_rx = chain_plane_x + PG * e_r; + let dst_ry = chain_plane_y + PG * e_r; + + chain_buf[dst_lx + 0u] = active_sums[src_lx + 0u]; + chain_buf[dst_lx + 1u] = active_sums[src_lx + 1u]; + chain_buf[dst_ly + 0u] = active_sums[src_ly + 0u]; + chain_buf[dst_ly + 1u] = active_sums[src_ly + 1u]; + chain_buf[dst_rx + 0u] = active_sums[src_rx + 0u]; + chain_buf[dst_rx + 1u] = active_sums[src_rx + 1u]; + chain_buf[dst_ry + 0u] = active_sums[src_ry + 0u]; + chain_buf[dst_ry + 1u] = active_sums[src_ry + 1u]; + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_pairs_prod.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_pairs_prod.template.wgsl new file mode 100644 index 
000000000000..3cff285af1ca --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_pairs_prod.template.wgsl @@ -0,0 +1,65 @@ +{{> structs }} + +// Marshal kernel — prod variant for the v2 pair-tree integration. +// +// Same indexing math as ba_marshal_pairs_bench. The only structural +// change: the per-level T (= num_chunks) is read from the planner's +// totals[3] storage output instead of a host-set uniform, and the +// host dispatches via dispatchWorkgroupsIndirect(totals, 16). This +// dispatches exactly ceil(num_chunks / WG) workgroups so no pad +// chunks are computed. + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var chunk_plan: array; +@group(0) @binding(1) var active_sums: array>; +@group(0) @binding(2) var chain_buf: array>; +@group(0) @binding(3) var totals: array; +@group(0) @binding(4) var consts: vec4; +// consts.x = M_in + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = totals[3]; + let M_in = consts.x; + let t = gid.x; + if (t >= T) { return; } + + let chain_N = 2u * S * T; + let chain_plane_x = 0u * PG * chain_N; + let chain_plane_y = 1u * PG * chain_N; + + let active_plane_x = 0u * PG * M_in; + let active_plane_y = 1u * PG * M_in; + + let chunk_base = 2u * S * t; + for (var k: u32 = 0u; k < S; k = k + 1u) { + let idx_l = chunk_plan[chunk_base + 2u * k + 0u]; + let idx_r = chunk_plan[chunk_base + 2u * k + 1u]; + + let e_l = t + (2u * k + 0u) * T; + let e_r = t + (2u * k + 1u) * T; + + let src_lx = active_plane_x + PG * idx_l; + let src_ly = active_plane_y + PG * idx_l; + let src_rx = active_plane_x + PG * idx_r; + let src_ry = active_plane_y + PG * idx_r; + + let dst_lx = chain_plane_x + PG * e_l; + let dst_ly = chain_plane_y + PG * e_l; + let dst_rx = chain_plane_x + PG * e_r; + let dst_ry = chain_plane_y + PG * e_r; + + chain_buf[dst_lx + 0u] = active_sums[src_lx + 0u]; + chain_buf[dst_lx + 1u] = active_sums[src_lx + 1u]; + chain_buf[dst_ly 
+ 0u] = active_sums[src_ly + 0u]; + chain_buf[dst_ly + 1u] = active_sums[src_ly + 1u]; + chain_buf[dst_rx + 0u] = active_sums[src_rx + 0u]; + chain_buf[dst_rx + 1u] = active_sums[src_rx + 1u]; + chain_buf[dst_ry + 0u] = active_sums[src_ry + 0u]; + chain_buf[dst_ry + 1u] = active_sums[src_ry + 1u]; + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_tree_l0_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_tree_l0_bench.template.wgsl new file mode 100644 index 000000000000..4b1539600e64 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_marshal_tree_l0_bench.template.wgsl @@ -0,0 +1,64 @@ +{{> structs }} + +// Marshal kernel for the bench-msm-tree pair-tree pipeline: transposes +// a CSR-sorted point index list into the 2-plane strided SoA layout +// the ba_pair_disjoint_tree kernel consumes at level 0. Pure memory +// shuffle, no field arithmetic. +// +// Input (point_pool): +// 2 planes (P.x, P.y), each PG=2 vec4 per element, N pool elements. +// Plane p flat vec4 indices: p*PG*N + PG*i + {0,1}. +// +// Output (chain_buf): +// 2 planes (P.x, P.y), each PG=2 vec4 per element, 2*S*T elements +// per plane. Plane p at strided element e = t + i*T: vec4 indices +// p*PG*(2*S*T) + PG*e + {0,1}. 
+// +// Per chunk-thread t with CSR slice [csr_start, csr_start + 2*S): +// For i in 0..2*S: +// pt_idx = csr_indices[csr_start + i] +// copy point_pool[pt_idx] (P.x, P.y) into chain_buf at e = t + i*T + +const S: u32 = {{ s }}u; +const TWOS: u32 = 2u * S; +const PG: u32 = 2u; + +@group(0) @binding(0) var csr_indices: array; +@group(0) @binding(1) var chunk_plan: array; +@group(0) @binding(2) var point_pool: array>; +@group(0) @binding(3) var chain_buf: array>; +@group(0) @binding(4) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let N = params.y; + let t = gid.x; + if (t >= T) { return; } + + let csr_start = chunk_plan[2u * t + 1u]; + + let chain_N = TWOS * T; + let chain_plane = PG * chain_N; + let chain_px_base = 0u * chain_plane; + let chain_py_base = 1u * chain_plane; + + let pool_plane = PG * N; + let pool_px_base = 0u * pool_plane; + let pool_py_base = 1u * pool_plane; + + for (var i: u32 = 0u; i < TWOS; i = i + 1u) { + let pt_idx = csr_indices[csr_start + i]; + let e = t + i * T; + let pool_x_off = pool_px_base + PG * pt_idx; + let pool_y_off = pool_py_base + PG * pt_idx; + let chain_px_off = chain_px_base + PG * e; + let chain_py_off = chain_py_base + PG * e; + chain_buf[chain_px_off + 0u] = point_pool[pool_x_off + 0u]; + chain_buf[chain_px_off + 1u] = point_pool[pool_x_off + 1u]; + chain_buf[chain_py_off + 0u] = point_pool[pool_y_off + 0u]; + chain_buf[chain_py_off + 1u] = point_pool[pool_y_off + 1u]; + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_pair_disjoint_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_pair_disjoint_bench.template.wgsl new file mode 100644 index 000000000000..d5a83e646f1b --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_pair_disjoint_bench.template.wgsl @@ -0,0 +1,138 @@ +{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> 
fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +// Disjoint pair-sum kernel — each thread reduces 2*S input points to S +// disjoint pair sums R_k = P_{2k} + P_{2k+1} (k in 0..S) using the +// same forward-prefix / single-inversion / backward-peel batched- +// inverse pattern as ba_rev_packed_carry, but with NO load-carry +// overlap. Every kernel-output is a distinct pair sum suitable as +// input to the next level of a pair-tree reduction — closes the 50% +// kernel-efficiency loss inherent in the streaming chain kernel. +// +// Storage: SoA-packed 8x u32 per field (PG=2 vec4/elem). +// Input planes (binding 0): +// plane 0 (P.x): PG * N_in vec4, N_in = 2*S*T +// plane 1 (P.y): PG * N_in vec4 +// Output planes (binding 2): +// plane 0 (R.x): PG * N_out vec4, N_out = S*T +// plane 1 (R.y): PG * N_out vec4 +// +// Thread t reads P_i = (inp[plane c at index t + i*T] : c in {0,1}) for +// i in 0..2S (strided => coalesced). Pair k pairs adjacent strided +// slots: (P_{2k}, P_{2k+1}). Output R_k is written at index t + k*T in +// plane c of outp (also strided, coalesced). +// +// dx values dx_k = P_{2k+1}.x - P_{2k}.x are all mutually independent +// (no shared inputs across k), so the standard Montgomery batched +// inverse trick applies as-is: ONE fr_inv_by_a per chunk of S. +// +// Same Karatsuba+Yuval montmul and BY-safegcd fr_inv_by_a as the +// production stack and the chain kernel. 
+ +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var inp: array>; +@group(0) @binding(1) var unused: array>; +@group(0) @binding(2) var outp: array>; +@group(0) @binding(3) var params: vec4; + +fn load_in(plane: u32, t: u32, i: u32, T: u32, N_in: u32) -> BigInt { + let plane_base = plane * PG * N_in; + let base = plane_base + PG * (t + i * T); + let q0 = inp[base + 0u]; + let q1 = inp[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_out(plane: u32, t: u32, k: u32, T: u32, N_out: u32, val: ptr) { + let plane_base = plane * PG * N_out; + let base = plane_base + PG * (t + k * T); + let w = pack_limbs_to_256(val); + outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let N_in = params.x; + let T = params.y; + let N_out = N_in / 2u; + + let t = gid.x; + if (t >= T) { return; } + + // Forward: prefix product of S independent dx values. + var pref: array; + var acc: BigInt = get_r(); + for (var k: u32 = 0u; k < S; k = k + 1u) { + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T, N_in); + var dx: BigInt = fr_sub(&p_rx, &p_lx); + if (k == 0u) { + acc = dx; + } else { + acc = montgomery_product(&acc, &dx); + } + pref[k] = acc; + } + + // One BY-safegcd inversion amortised over all S pair sums. + var inv: BigInt = fr_inv_by_a(acc); + + // Backward peel: emit S disjoint pair sums. 
+ for (var jj: u32 = 0u; jj < S; jj = jj + 1u) { + let k = S - 1u - jj; + + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T, N_in); + var p_ly: BigInt = load_in(1u, t, 2u * k + 0u, T, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T, N_in); + var p_ry: BigInt = load_in(1u, t, 2u * k + 1u, T, N_in); + + var inv_dx: BigInt; + if (k == 0u) { + inv_dx = inv; + } else { + var pp = pref[k - 1u]; + inv_dx = montgomery_product(&inv, &pp); + } + + var lambda: BigInt = fr_sub(&p_ry, &p_ly); + lambda = montgomery_product(&lambda, &inv_dx); + var r_x: BigInt = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &p_lx); + r_x = fr_sub(&r_x, &p_rx); + var r_y: BigInt = fr_sub(&p_lx, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &p_ly); + + store_out(0u, t, k, T, N_out, &r_x); + store_out(1u, t, k, T, N_out, &r_y); + + // Advance inv to 1/pref[k-1] for the next (smaller) iteration. + if (k > 0u) { + var dx_back: BigInt = fr_sub(&p_rx, &p_lx); + inv = montgomery_product(&inv, &dx_back); + } + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_pair_disjoint_tree_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_pair_disjoint_tree_bench.template.wgsl new file mode 100644 index 000000000000..d12b6176ede1 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_pair_disjoint_tree_bench.template.wgsl @@ -0,0 +1,169 @@ +{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +// Disjoint pair-sum kernel — tree variant. Each thread reduces 2*S +// input points to S disjoint pair sums R_k = P_{2k} + P_{2k+1}, using +// one batched fr_inv_by_a per chunk of S. 
+// +// vs ba_pair_disjoint_bench: writes outputs in the LAYOUT THE NEXT +// PAIR-TREE LEVEL EXPECTS AS INPUT, eliminating the need for an +// intervening marshal/reshuffle dispatch between levels. +// +// Strided read at the current level: thread t reads input slot i at flat +// in_pos(t, i) = t + i * T_curr (i in [0, 2*S)) +// +// Strided write that next level reads correctly: thread t writes +// output slot i (i in [0, S)) at flat +// out_pos(t, i) = (t >> 1) + (i + S * (t & 1)) * (T_curr >> 1) +// +// Derivation: next level uses T_next = T_curr / 2 threads. For +// next-level thread t_n = t >> 1 to read its 2*S inputs in the right +// pair-tree order (first S from prev thread (2*t_n), next S from prev +// thread (2*t_n + 1)), the current level's output slots interleave: +// odd-t writes go into the upper-S input slots of the next level's +// thread (t >> 1), even-t into the lower-S slots. +// +// This preserves the per-bucket-pair invariant: at every level, the +// disjoint pairs (P_{2j}, P_{2j+1}) belong to the same bucket pool, +// so the lean affine formula is always combining points whose dx is +// well-defined. +// +// PARAMS: +// params.x = N_in = 2 * S * T_curr (total input elements per plane) +// params.y = T_curr; params.z = final_flag (non-zero => simple strided write, zero => tree write) +// +// LAYOUT (both input and output buffers): +// 2 planes (P.x, P.y), PG=2 vec4 per element. +// Plane p flat index for vec4 access: p * PG * N_buf + PG * e + {0,1} +// where N_buf is the elements-per-plane for that buffer. +// Input buffer's N_buf = 2 * S * T_curr (= N_in). +// Output buffer's N_buf = S * T_curr (= N_in / 2). 
+ +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var inp: array>; +@group(0) @binding(1) var unused: array>; +@group(0) @binding(2) var outp: array>; +@group(0) @binding(3) var params: vec4; + +fn load_in(plane: u32, t: u32, i: u32, T: u32, N_in: u32) -> BigInt { + let plane_base = plane * PG * N_in; + let base = plane_base + PG * (t + i * T); + let q0 = inp[base + 0u]; + let q1 = inp[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_out_tree(plane: u32, t: u32, k: u32, T_curr: u32, N_out: u32, val: ptr) { + // Tree write: out_pos(t, k) = (t >> 1) + (k + S * (t & 1)) * (T_curr >> 1) + // Lands in next-level strided read at index (t >> 1) with slot + // (k + S * (t & 1)). + let t_next = t >> 1u; + let slot_in_next = k + S * (t & 1u); + let T_next = T_curr >> 1u; + let plane_base = plane * PG * N_out; + let elem = t_next + slot_in_next * T_next; + let base = plane_base + PG * elem; + let w = pack_limbs_to_256(val); + outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn store_out_simple(plane: u32, t: u32, k: u32, T_curr: u32, N_out: u32, val: ptr) { + // Final-level simple strided write: out_pos(t, k) = t + k * T_curr. + // Used when there is no next pair-tree level (T_curr == 1 thread, or + // the host indicates this is the last reduction step). 
+ let plane_base = plane * PG * N_out; + let elem = t + k * T_curr; + let base = plane_base + PG * elem; + let w = pack_limbs_to_256(val); + outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let N_in = params.x; + let T_curr = params.y; + let final_flag = params.z; // non-zero => use simple strided write + let N_out = N_in / 2u; + + let t = gid.x; + if (t >= T_curr) { return; } + + var pref: array; + var acc: BigInt = get_r(); + for (var k: u32 = 0u; k < S; k = k + 1u) { + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T_curr, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T_curr, N_in); + var dx: BigInt = fr_sub(&p_rx, &p_lx); + if (k == 0u) { + acc = dx; + } else { + acc = montgomery_product(&acc, &dx); + } + pref[k] = acc; + } + + var inv: BigInt = fr_inv_by_a(acc); + + for (var jj: u32 = 0u; jj < S; jj = jj + 1u) { + let k = S - 1u - jj; + + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T_curr, N_in); + var p_ly: BigInt = load_in(1u, t, 2u * k + 0u, T_curr, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T_curr, N_in); + var p_ry: BigInt = load_in(1u, t, 2u * k + 1u, T_curr, N_in); + + var inv_dx: BigInt; + if (k == 0u) { + inv_dx = inv; + } else { + var pp = pref[k - 1u]; + inv_dx = montgomery_product(&inv, &pp); + } + + var lambda: BigInt = fr_sub(&p_ry, &p_ly); + lambda = montgomery_product(&lambda, &inv_dx); + var r_x: BigInt = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &p_lx); + r_x = fr_sub(&r_x, &p_rx); + var r_y: BigInt = fr_sub(&p_lx, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &p_ly); + + if (final_flag != 0u) { + store_out_simple(0u, t, k, T_curr, N_out, &r_x); + store_out_simple(1u, t, k, T_curr, N_out, &r_y); + } else { + store_out_tree(0u, t, k, T_curr, N_out, 
&r_x); + store_out_tree(1u, t, k, T_curr, N_out, &r_y); + } + + if (k > 0u) { + var dx_back: BigInt = fr_sub(&p_rx, &p_lx); + inv = montgomery_product(&inv, &dx_back); + } + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_pair_disjoint_tree_prod.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_pair_disjoint_tree_prod.template.wgsl new file mode 100644 index 000000000000..3c7b7b3d4504 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_pair_disjoint_tree_prod.template.wgsl @@ -0,0 +1,119 @@ +{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +// Disjoint pair-sum kernel — prod variant for the v2 pair-tree +// integration. Same disjoint pair-sum math as +// ba_pair_disjoint_tree_bench (prefix-product single fr_inv_by_a per +// chunk + lean affine add); the per-level T (= num_chunks) is read +// from the planner's totals[3] storage output and the dispatch happens +// indirectly so only real chunks run. Always uses the final-mode +// strided write (matches what ba_scatter_pairs_prod expects). +// +// LAYOUT: same as the bench variant. Combined-SoA input/output (2 +// planes, PG=2 vec4 per element, plane-major then element-major then +// vec4 within an element). 
+ +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var inp: array>; +@group(0) @binding(1) var unused: array>; +@group(0) @binding(2) var outp: array>; +@group(0) @binding(3) var totals: array; + +fn load_in(plane: u32, t: u32, i: u32, T: u32, N_in: u32) -> BigInt { + let plane_base = plane * PG * N_in; + let base = plane_base + PG * (t + i * T); + let q0 = inp[base + 0u]; + let q1 = inp[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_out_simple(plane: u32, t: u32, k: u32, T_curr: u32, N_out: u32, val: ptr) { + let plane_base = plane * PG * N_out; + let elem = t + k * T_curr; + let base = plane_base + PG * elem; + let w = pack_limbs_to_256(val); + outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T_curr = totals[3]; + let N_in = 2u * S * T_curr; + let N_out = S * T_curr; + + let t = gid.x; + if (t >= T_curr) { return; } + + var pref: array; + var acc: BigInt = get_r(); + for (var k: u32 = 0u; k < S; k = k + 1u) { + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T_curr, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T_curr, N_in); + var dx: BigInt = fr_sub(&p_rx, &p_lx); + if (k == 0u) { + acc = dx; + } else { + acc = montgomery_product(&acc, &dx); + } + pref[k] = acc; + } + + var inv: BigInt = fr_inv_by_a(acc); + + for (var jj: u32 = 0u; jj < S; jj = jj + 1u) { + let k = S - 1u - jj; + + var p_lx: BigInt = load_in(0u, t, 2u * k + 0u, T_curr, N_in); + var p_ly: BigInt = load_in(1u, t, 2u * k + 0u, T_curr, N_in); + var p_rx: BigInt = load_in(0u, t, 2u * k + 1u, T_curr, N_in); + var p_ry: BigInt = load_in(1u, t, 2u * k + 1u, T_curr, N_in); + + var inv_dx: BigInt; + if (k == 
0u) { + inv_dx = inv; + } else { + var pp = pref[k - 1u]; + inv_dx = montgomery_product(&inv, &pp); + } + + var lambda: BigInt = fr_sub(&p_ry, &p_ly); + lambda = montgomery_product(&lambda, &inv_dx); + var r_x: BigInt = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &p_lx); + r_x = fr_sub(&r_x, &p_rx); + var r_y: BigInt = fr_sub(&p_lx, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &p_ly); + + store_out_simple(0u, t, k, T_curr, N_out, &r_x); + store_out_simple(1u, t, k, T_curr, N_out, &r_y); + + if (k > 0u) { + var dx_back: BigInt = fr_sub(&p_rx, &p_lx); + inv = montgomery_product(&inv, &dx_back); + } + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_bench.template.wgsl new file mode 100644 index 000000000000..1d49e4298849 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_bench.template.wgsl @@ -0,0 +1,93 @@ +{{> structs }} + +// GPU-side bin-packing planner for the v3 MSM bucket-accumulate +// pipeline. One thread per bucket; uses atomicAdd to reserve global +// per-pair slots in chunk_plan / scatter_plan and per-carry slots in +// carry_plan, then writes that bucket's entries. 
+// +// Inputs (per current level): +// counts: array per-bucket active count +// offsets: array per-bucket starting index in active_sums_old +// +// Outputs (filled in by this kernel for the current level): +// chunk_plan: array 2 u32 per (chunk_id, slot) — pair operand indices +// scatter_plan: array 1 u32 per (chunk_id, slot) — destination in active_sums_new +// carry_plan: array 2 u32 per carry slot — (src in old, dst in new) +// totals: array> [0]=total pairs, [1]=total carries, [2]=total new actives +// new_counts: array per-bucket new active count (for next level) +// new_offsets: array per-bucket new offset in active_sums_new (for next level) +// +// Convention: discard slot = M_new - 1 (the highest index in +// active_sums_new). Pad pair source indices = (pad_l_idx, pad_r_idx) +// supplied via params. All non-real chunk_plan / scatter_plan slots +// must be pre-padded to (pad_l_idx, pad_r_idx) and discard_idx by the +// host before each planner dispatch. +// +// params.x = B (bucket count) +// params.y = S (chunk size, slots per chunk) +// (pad_l_idx / pad_r_idx / discard_idx live in the pre-padded +// arrays, not in params) + +const S: u32 = {{ s }}u; + +@group(0) @binding(0) var counts: array; +@group(0) @binding(1) var offsets: array; +@group(0) @binding(2) var chunk_plan: array; +@group(0) @binding(3) var scatter_plan: array; +@group(0) @binding(4) var carry_plan: array; +@group(0) @binding(5) var totals: array>; +@group(0) @binding(6) var new_counts: array; +@group(0) @binding(7) var new_offsets: array; +@group(0) @binding(8) var params: vec4; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let B = params.x; + let b = gid.x; + if (b >= B) { return; } + + let n = counts[b]; + let pair_count = n / 2u; + let carry_flag = n & 1u; + let nc = pair_count + carry_flag; + new_counts[b] = nc; + + // Atomic offset reservation. Each bucket gets a unique non-overlapping + // range in the global arrays. 
Atomic order is non-deterministic but + // that's fine: bucket b records its assigned offsets and uses them + // consistently for its own chunk_plan / scatter_plan / new_offsets + // writes. Different buckets land in different ranges by construction. + let my_pair_off = atomicAdd(&totals[0u], pair_count); + let my_carry_off = atomicAdd(&totals[1u], carry_flag); + let my_new_off = atomicAdd(&totals[2u], nc); + new_offsets[b] = my_new_off; + + let bucket_base = offsets[b]; + + // Write this bucket's pair entries into chunk_plan / scatter_plan. + // Loop bounded by pair_count (variable per bucket; typically ~16 + // for Poisson(λ=32)). The TAIL_CAP-style compile-time bound used + // by ba_tail_reduce isn't strictly needed here since this kernel + // doesn't do field arithmetic; the loop is plain integer writes. + // We still bound it by a compile-time constant for WGSL static + // analysis purposes. + let PAIR_CAP: u32 = {{ pair_cap }}u; + for (var j: u32 = 0u; j < PAIR_CAP; j = j + 1u) { + if (j >= pair_count) { break; } + let global_slot = my_pair_off + j; + let chunk_id = global_slot / S; + let slot_in_chunk = global_slot % S; + let cp_base = 2u * (chunk_id * S + slot_in_chunk); + chunk_plan[cp_base + 0u] = bucket_base + 2u * j; + chunk_plan[cp_base + 1u] = bucket_base + 2u * j + 1u; + scatter_plan[chunk_id * S + slot_in_chunk] = my_new_off + j; + } + + if (carry_flag != 0u) { + let cs = my_carry_off; + carry_plan[2u * cs + 0u] = bucket_base + n - 1u; + carry_plan[2u * cs + 1u] = my_new_off + pair_count; + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_bench.template.wgsl new file mode 100644 index 000000000000..789747d86f82 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_bench.template.wgsl @@ -0,0 +1,170 @@ +{{> structs }} + +// Optimal single-kernel GPU bin-packing planner for the MSM +// bucket-accumulate 
pair-tree. +// +// One workgroup of TPB threads processes B buckets. Each thread +// handles PER_THREAD = B / TPB buckets via a contiguous slice +// [tid * PER_THREAD, (tid+1) * PER_THREAD). +// +// Phase A — Per-thread local scan +// For each of its PER_THREAD buckets, compute (pair_count, carry_flag, +// new_count). Accumulate per-thread totals (sum across the thread's +// slice). Keep the per-bucket triples in registers; we will re-scan +// them in Phase B. +// +// Phase B — Workgroup-wide Hillis-Steele scan (3 in parallel) +// Scan the per-thread totals for pair, carry, new across the TPB +// threads in shared memory. Result: each thread gets the global +// prefix sum at the START of its slice (= base offset for its first +// bucket). +// +// Phase C — Per-thread scatter +// For each bucket in the thread's slice (in order), use the running +// thread-local offset to compute global pair_offset_b and write the +// pair_count[b] chunk_plan entries plus the (optional) carry_plan +// entry. Update local running offsets. Write new_counts[b] and +// new_offsets[b] for the next level. +// +// Phase D — One thread writes totals. +// totals[0] = total_pairs, totals[1] = total_carries, +// totals[2] = total_new_actives. +// +// Single dispatch. No atomics. No host sync. Scales to B = TPB * +// PER_THREAD (e.g. 256 * 32 = 8192) within one workgroup. Larger B +// requires multi-workgroup scan + global combine (out of scope here). +// +// Compile-time constants: +// TPB : workgroup size (e.g. 256) +// PER_THREAD : buckets per thread (e.g. 16 for B=4096, 32 for B=8192) +// PAIR_CAP : bound on per-bucket pair count (Poisson(λ=32) tail +// is ~30; choose 64 for safety) +// S : chunk size in pairs (e.g. 
16) + +const TPB: u32 = {{ workgroup_size }}u; +const PER_THREAD: u32 = {{ per_thread }}u; +const PAIR_CAP: u32 = {{ pair_cap }}u; +const S: u32 = {{ s }}u; + +@group(0) @binding(0) var counts: array; +@group(0) @binding(1) var offsets: array; +@group(0) @binding(2) var chunk_plan: array; +@group(0) @binding(3) var scatter_plan: array; +@group(0) @binding(4) var carry_plan: array; +@group(0) @binding(5) var new_counts: array; +@group(0) @binding(6) var new_offsets: array; +@group(0) @binding(7) var totals: array; +@group(0) @binding(8) var params: vec4; +// params.x = B + +// Workgroup-shared running prefixes for the 3 scans. +var pair_scan: array; +var carry_scan: array; +var new_scan: array; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; + let B = params.x; + + // Phase A: per-thread local read + accumulate. + // Keep PER_THREAD bucket triples in registers (small array). + var local_pc: array; + var local_cf: array; + var local_nc: array; + var sum_p: u32 = 0u; + var sum_c: u32 = 0u; + var sum_n: u32 = 0u; + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tid * PER_THREAD + k; + var pc: u32 = 0u; + var cf: u32 = 0u; + var nc: u32 = 0u; + if (b < B) { + let n = counts[b]; + pc = n / 2u; + cf = n & 1u; + nc = pc + cf; + } + local_pc[k] = pc; + local_cf[k] = cf; + local_nc[k] = nc; + sum_p += pc; + sum_c += cf; + sum_n += nc; + } + + // Phase B: workgroup-wide Hillis-Steele inclusive scan over per- + // thread totals (3 scans interleaved). 
+ pair_scan[tid] = sum_p; + carry_scan[tid] = sum_c; + new_scan[tid] = sum_n; + workgroupBarrier(); + for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) { + var add_p: u32 = 0u; + var add_c: u32 = 0u; + var add_n: u32 = 0u; + if (tid >= stride) { + add_p = pair_scan[tid - stride]; + add_c = carry_scan[tid - stride]; + add_n = new_scan[tid - stride]; + } + workgroupBarrier(); + if (tid >= stride) { + pair_scan[tid] = pair_scan[tid] + add_p; + carry_scan[tid] = carry_scan[tid] + add_c; + new_scan[tid] = new_scan[tid] + add_n; + } + workgroupBarrier(); + } + // pair_scan[tid] is now inclusive prefix. Exclusive base = inclusive - own_sum. + var local_pair_off: u32 = pair_scan[tid] - sum_p; + var local_carry_off: u32 = carry_scan[tid] - sum_c; + var local_new_off: u32 = new_scan[tid] - sum_n; + + // Phase D: the last thread (tid == TPB - 1) writes totals; its inclusive scan value is the grand total. + if (tid == TPB - 1u) { + totals[0] = pair_scan[tid]; + totals[1] = carry_scan[tid]; + totals[2] = new_scan[tid]; + } + + // Phase C: per-thread scatter. + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tid * PER_THREAD + k; + if (b >= B) { break; } + + let pc = local_pc[k]; + let cf = local_cf[k]; + let nc = local_nc[k]; + new_counts[b] = nc; + new_offsets[b] = local_new_off; + + let bucket_base = offsets[b]; + + // Pair entries: bounded loop, break at pc. + for (var j: u32 = 0u; j < PAIR_CAP; j = j + 1u) { + if (j >= pc) { break; } + let global_slot = local_pair_off + j; + let chunk_id = global_slot / S; + let slot_in_chunk = global_slot % S; + let cp_base = 2u * (chunk_id * S + slot_in_chunk); + chunk_plan[cp_base + 0u] = bucket_base + 2u * j; + chunk_plan[cp_base + 1u] = bucket_base + 2u * j + 1u; + scatter_plan[chunk_id * S + slot_in_chunk] = local_new_off + j; + } + + // Carry entry (if odd count). 
+ if (cf != 0u) { + let cs = local_carry_off; + carry_plan[2u * cs + 0u] = bucket_base + counts[b] - 1u; + carry_plan[2u * cs + 1u] = local_new_off + pc; + } + + local_pair_off += pc; + local_carry_off += cf; + local_new_off += nc; + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_prod.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_prod.template.wgsl new file mode 100644 index 000000000000..811118992dc5 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_planner_v2_prod.template.wgsl @@ -0,0 +1,188 @@ +{{> structs }} + +// Production GPU bin-packing planner for the v2 pair-tree integration. +// +// Same algorithm as ba_planner_v2_bench (one workgroup of TPB threads, +// per-thread local scan, workgroup-wide Hillis-Steele scan over the +// three running sums, per-thread scatter) but extends the totals +// output with the indirect-dispatch counts the production marshal / +// disjoint / scatter / carry kernels need: +// +// totals[0] = total_pairs +// totals[1] = total_carries +// totals[2] = total_new +// totals[3] = num_chunks = (total_pairs + S - 1) / S (ceil division; 0 when total_pairs == 0) +// totals[4] = marshal/disjoint/scatter dispatch X (= ceil(num_chunks / WGI)) +// totals[5] = 1 +// totals[6] = 1 +// totals[7] = carry dispatch X (= ceil(total_carries / WGI)) +// totals[8] = 1 +// totals[9] = 1 +// +// The four prod-variant downstream kernels (ba_marshal_pairs_prod, +// ba_pair_disjoint_tree_prod, ba_scatter_pairs_prod, ba_carry_copy_prod) +// read num_chunks and total_carries from this same totals storage +// buffer so a single planner dispatch fully drives the level's runtime +// shape with zero wasted-pad-chunk compute. The host orchestrator +// reuses the totals buffer as the indirect-dispatch source via +// dispatchWorkgroupsIndirect(totals, 16) for marshal/disjoint/scatter +// (totals u32 indices 4..6) and dispatchWorkgroupsIndirect(totals, 28) +// for carry (totals u32 indices 7..9). 
+// +// Compile-time constants: +// TPB : workgroup size (e.g. 256) +// PER_THREAD : buckets per thread +// PAIR_CAP : per-bucket pair-count bound +// S : chunk size in pairs +// WGI : downstream kernel workgroup size — must match the +// workgroup_size of ba_marshal_pairs_prod / +// ba_pair_disjoint_tree_prod / ba_scatter_pairs_prod / +// ba_carry_copy_prod. + +const TPB: u32 = {{ workgroup_size }}u; +const PER_THREAD: u32 = {{ per_thread }}u; +const PAIR_CAP: u32 = {{ pair_cap }}u; +const S: u32 = {{ s }}u; +const WGI: u32 = {{ wgi }}u; + +@group(0) @binding(0) var counts: array; +@group(0) @binding(1) var offsets: array; +@group(0) @binding(2) var chunk_plan: array; +@group(0) @binding(3) var scatter_plan: array; +@group(0) @binding(4) var carry_plan: array; +@group(0) @binding(5) var new_counts: array; +@group(0) @binding(6) var new_offsets: array; +@group(0) @binding(7) var totals: array; +@group(0) @binding(8) var params: vec4; +// params.x = B +// params.y = pad_left_idx (active_sums index used for chunk_plan tail pad left operand) +// params.z = pad_right_idx (chunk_plan tail pad right operand; must differ from pad_left_idx in x) +// params.w = discard_idx (scatter_plan tail dst; slot that the next level never reads) + +var pair_scan: array; +var carry_scan: array; +var new_scan: array; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; + let B = params.x; + + var local_pc: array; + var local_cf: array; + var local_nc: array; + var sum_p: u32 = 0u; + var sum_c: u32 = 0u; + var sum_n: u32 = 0u; + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tid * PER_THREAD + k; + var pc: u32 = 0u; + var cf: u32 = 0u; + var nc: u32 = 0u; + if (b < B) { + let n = counts[b]; + pc = n / 2u; + cf = n & 1u; + nc = pc + cf; + } + local_pc[k] = pc; + local_cf[k] = cf; + local_nc[k] = nc; + sum_p += pc; + sum_c += cf; + sum_n += nc; + } + + pair_scan[tid] = sum_p; + carry_scan[tid] = sum_c; + 
new_scan[tid] = sum_n; + workgroupBarrier(); + for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) { + var add_p: u32 = 0u; + var add_c: u32 = 0u; + var add_n: u32 = 0u; + if (tid >= stride) { + add_p = pair_scan[tid - stride]; + add_c = carry_scan[tid - stride]; + add_n = new_scan[tid - stride]; + } + workgroupBarrier(); + if (tid >= stride) { + pair_scan[tid] = pair_scan[tid] + add_p; + carry_scan[tid] = carry_scan[tid] + add_c; + new_scan[tid] = new_scan[tid] + add_n; + } + workgroupBarrier(); + } + var local_pair_off: u32 = pair_scan[tid] - sum_p; + var local_carry_off: u32 = carry_scan[tid] - sum_c; + var local_new_off: u32 = new_scan[tid] - sum_n; + + if (tid == TPB - 1u) { + let tp = pair_scan[tid]; + let tc = carry_scan[tid]; + let tn = new_scan[tid]; + totals[0] = tp; + totals[1] = tc; + totals[2] = tn; + let num_chunks = (tp + S - 1u) / S; + totals[3] = num_chunks; + totals[4] = (num_chunks + WGI - 1u) / WGI; + totals[5] = 1u; + totals[6] = 1u; + totals[7] = (tc + WGI - 1u) / WGI; + totals[8] = 1u; + totals[9] = 1u; + } + + for (var k: u32 = 0u; k < PER_THREAD; k = k + 1u) { + let b = tid * PER_THREAD + k; + if (b >= B) { break; } + + let pc = local_pc[k]; + let cf = local_cf[k]; + let nc = local_nc[k]; + new_counts[b] = nc; + new_offsets[b] = local_new_off; + + let bucket_base = offsets[b]; + + for (var j: u32 = 0u; j < PAIR_CAP; j = j + 1u) { + if (j >= pc) { break; } + let global_slot = local_pair_off + j; + let chunk_id = global_slot / S; + let slot_in_chunk = global_slot % S; + let cp_base = 2u * (chunk_id * S + slot_in_chunk); + chunk_plan[cp_base + 0u] = bucket_base + 2u * j; + chunk_plan[cp_base + 1u] = bucket_base + 2u * j + 1u; + scatter_plan[chunk_id * S + slot_in_chunk] = local_new_off + j; + } + + if (cf != 0u) { + let cs = local_carry_off; + carry_plan[2u * cs + 0u] = bucket_base + counts[b] - 1u; + carry_plan[2u * cs + 1u] = local_new_off + pc; + } + + local_pair_off += pc; + local_carry_off += cf; + local_new_off += nc; + } + + 
workgroupBarrier(); + if (tid == TPB - 1u) { + let tp = pair_scan[tid]; + let num_chunks = (tp + S - 1u) / S; + let pad_end = num_chunks * S; + let pad_left = params.y; + let pad_right = params.z; + let discard_idx = params.w; + for (var i: u32 = tp; i < pad_end; i = i + 1u) { + chunk_plan[2u * i + 0u] = pad_left; + chunk_plan[2u * i + 1u] = pad_right; + scatter_plan[i] = discard_idx; + } + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_rev_packed_carry_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_rev_packed_carry_bench.template.wgsl new file mode 100644 index 000000000000..f0f6a7206649 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_rev_packed_carry_bench.template.wgsl @@ -0,0 +1,172 @@ +{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +// MSM-integrated bucket-accumulate batch-affine kernel — packed 8x u32 +// storage + decoupled (full-ILP) pack/unpack + reversed direction + +// resident-accumulator load-carry. Drives the canonical +// ba_rev_packed_carry benchmark that reached ~22 ns/pair on M2 / Chrome +// 148 (-55% vs the production batch-affine kernel). +// +// Math is byte-identical to ba_msm_bucket_bench: forward running +// prefix-product of the S dx values in a private array, ONE +// fr_inv_by_a per chunk of S, backward peel with the lean affine +// formula (dx recomputed free in the backward pass), resident +// accumulator A.x kept in registers across the whole chunk (load-carry: +// A_{i+1} := P_i so the forward and backward passes share one global +// P_i.x load per iteration). Same Karatsuba+Yuval montmul and BY-safegcd +// fr_inv_by_a as the production stack. 
+// +// The single structural change from ba_msm_bucket_bench: +// global storage is the packed 254-bit value stored as 8x u32 +// (32 bytes/elem == 2x vec4), not the 20x 13-bit-limb BigInt +// (80 bytes/elem == 5x vec4). Unpack into 20x13-bit limbs only +// in-register at load and repack on store. The pack/unpack is the +// decoupled full-ILP straight-line form (injected below as +// unpack256_to_limbs / pack_limbs_to_256): 20 mutually-independent +// compile-time-constant-indexed limb extractions, zero loop-carried +// bit-cursor dependency chain. This cuts global traffic 2.5x (the +// dominant cost in the memory-bound batch-affine kernel) at a +// sub-cycle in-register cost. +// +// LAYOUT: packed elem = 2 vec4; for each of the 4 input planes +// (A.x, A.y, P.x, P.y) and 2 output planes (R.x, R.y), plane c holds +// N elements at indices c*2*N + 2*e + {0,1}. params.x = N (total +// point-adds), params.y = T (thread count = N/S). +// +// Thread t streams points e = t + i*T for i in 0..S (strided => fully +// coalesced across the apply phase). The "left" operand of add i is the +// running accumulator A_i; A_0 is the per-thread seed (plane 0/1 at +// e=t), A_{i+1} := P_i (load-carry; same global address as forward +// pass's P_i load, no extra global traffic). + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; // 8 u32 packed limbs / 4 = 2 vec4 groups + +@group(0) @binding(0) var inp: array>; +@group(0) @binding(1) var unused: array>; +@group(0) @binding(2) var outp: array>; +@group(0) @binding(3) var params: vec4; + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +// Packed I/O helpers. load_be_packed reads element e of the input +// plane starting at plane_base (two consecutive vec4 of inp) and +// unpacks its 8 u32 words with unpack256_to_limbs. store_be_packed is +// the mirror image: it packs *val with pack_limbs_to_256 and writes +// two vec4 to outp. Neither helper reads its N argument. get_r returns +// the BigInt populated by the injected r_limbs template. +fn load_be_packed(plane_base: u32, e: u32, N: u32) -> BigInt { + // plane_base is in vec4 units; per plane: 2*N vec4 (PG=2). 
+ let base = plane_base + PG * e; + let q0 = inp[base + 0u]; + let q1 = inp[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_be_packed(plane_base: u32, e: u32, N: u32, val: ptr) { + let w = pack_limbs_to_256(val); + let base = plane_base + PG * e; + outp[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + outp[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let N = params.x; + let T = params.y; + let t = gid.x; + if (t >= T) { return; } + + // Plane bases in vec4 units. Each plane spans PG*N vec4. + let plane = PG * N; + let ax_base = 0u * plane; + let ay_base = 1u * plane; + let px_base = 2u * plane; + let py_base = 3u * plane; + + // Resident accumulator A.x stays in registers across the whole + // chunk (drives the forward dx prefix chain). A.y is only needed in + // the backward peel and is re-loaded there from the same SoA plane. + var acc_x = load_be_packed(ax_base, t, N); + + // Forward pass: running prefix-product of the S dx values + // dx_i = P_i.x - A_i.x. A_i is the prefix accumulator (resident). + var pref: array; + var acc: BigInt = get_r(); + for (var i = 0u; i < S; i = i + 1u) { + let e = t + i * T; + var p_x = load_be_packed(px_base, e, N); + var dx = fr_sub(&p_x, &acc_x); + if (i == 0u) { + acc = dx; + } else { + acc = montgomery_product(&acc, &dx); + } + pref[i] = acc; + // Resident accumulator advances along the streamed chain: + // A_0 is the seed, A_{i+1} := P_i. Points are independent + // (P_i.x != A_i.x) so every dx is a well-defined nonzero + // difference. inv_dx is deferred to the backward pass (ONE + // fr_inv_by_a per chunk of S); A stays in registers throughout. 
+ acc_x = p_x; + } + + var inv: BigInt = fr_inv_by_a(acc); + + // Backward peel + lean affine formula (dx recomputed free). + // Loop invariant (given that fr_inv_by_a returns the inverse of the + // full prefix product): entering iteration i, inv is the inverse of + // dx_0 * ... * dx_i, so inv * pref[i - 1] isolates 1/dx_i; the + // trailing inv *= dx_i re-establishes the invariant for i - 1. + for (var jj = 0u; jj < S; jj = jj + 1u) { + let i = S - 1u - jj; + let e = t + i * T; + var p_x = load_be_packed(px_base, e, N); + var p_y = load_be_packed(py_base, e, N); + + // A_i (left operand): A_0 is the seed, A_i = P_{i-1} for i>0 + // (matches the forward acc_x recurrence; points independent so + // dx = P_i.x - A_i.x is always well-defined and nonzero). + var a_x: BigInt; + var a_y: BigInt; + if (i == 0u) { + a_x = load_be_packed(ax_base, t, N); + a_y = load_be_packed(ay_base, t, N); + } else { + let ep = t + (i - 1u) * T; + a_x = load_be_packed(px_base, ep, N); + a_y = load_be_packed(py_base, ep, N); + } + + var inv_dx: BigInt; + if (i == 0u) { + inv_dx = inv; + } else { + var pp = pref[i - 1u]; + inv_dx = montgomery_product(&inv, &pp); + } + + var lambda = fr_sub(&p_y, &a_y); + lambda = montgomery_product(&lambda, &inv_dx); + var r_x = montgomery_product(&lambda, &lambda); + r_x = fr_sub(&r_x, &a_x); + r_x = fr_sub(&r_x, &p_x); + var r_y = fr_sub(&a_x, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &a_y); + + // R_i is written to output planes 0 (x) and 1 (y) at the same + // element index e the input P_i was read from. 
+ store_be_packed(0u * plane, e, N, &r_x); + store_be_packed(1u * plane, e, N, &r_y); + + if (i != 0u) { + var dx_back = fr_sub(&p_x, &a_x); + inv = montgomery_product(&inv, &dx_back); + } + } +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_scatter_pairs_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_scatter_pairs_bench.template.wgsl new file mode 100644 index 000000000000..00d14390002c --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_scatter_pairs_bench.template.wgsl @@ -0,0 +1,61 @@ +{{> structs }} + +// Scatter kernel for the bin-packed pair-tree MSM bucket-accumulate.
+// +// For each (chunk t, slot k), reads R.x/R.y from the disjoint kernel's +// strided output (where it landed at flat index t + k * T after +// running with final_flag=1) and writes them to active_sums_new at +// the destination index given by scatter_plan[t * S + k]. +// +// This is the per-bucket-placement pass that re-groups pair sums for +// the next level's bin-packing planner. +// +// scatter_plan layout: 1 u32 per (chunk, slot). +// scatter_plan[t * S + k] = dst_idx (active_sums_new index) +// +// disjoint_out layout: 2 planes (R.x, R.y), PG=2 vec4 per element, +// S * T elements per plane (matches the disjoint kernel's +// final-mode simple strided write). +// +// active_sums_new layout: 2 planes (P.x, P.y), PG=2 vec4 per element, +// M_new elements per plane (params.y). + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var scatter_plan: array; +@group(0) @binding(1) var disjoint_out: array>; +@group(0) @binding(2) var active_sums_new: array>; +@group(0) @binding(3) var params: vec4; +// params.x = T (number of chunks, one thread each); params.y = M_new. + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + // Thread t copies the S slots of chunk t; threads with t >= T exit. + // Each copy moves the two vec4 of a packed coordinate verbatim (no + // unpack / repack). dst_idx comes straight from scatter_plan and is + // not range-checked against M_new here. + let T = params.x; + let M_new = params.y; + let t = gid.x; + if (t >= T) { return; } + + let out_N = S * T; + let out_plane_x = 0u * PG * out_N; + let out_plane_y = 1u * PG * out_N; + + let new_plane_x = 0u * PG * M_new; + let new_plane_y = 1u * PG * M_new; + + for (var k: u32 = 0u; k < S; k = k + 1u) { + let e = t + k * T; + let dst_idx = scatter_plan[t * S + k]; + + let src_x = out_plane_x + PG * e; + let src_y = out_plane_y + PG * e; + let dst_x = new_plane_x + PG * dst_idx; + let dst_y = new_plane_y + PG * dst_idx; + + active_sums_new[dst_x + 0u] = disjoint_out[src_x + 0u]; + active_sums_new[dst_x + 1u] = disjoint_out[src_x + 1u]; + active_sums_new[dst_y + 0u] = disjoint_out[src_y + 0u]; + active_sums_new[dst_y + 1u] = disjoint_out[src_y + 1u]; + } + + {{{ recompile }}} +} diff --git
a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_scatter_pairs_prod.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_scatter_pairs_prod.template.wgsl new file mode 100644 index 000000000000..4a14e8539736 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_scatter_pairs_prod.template.wgsl @@ -0,0 +1,48 @@ +{{> structs }} + +// Scatter kernel — prod variant for the v2 pair-tree integration. +// Same per-bucket placement math as ba_scatter_pairs_bench; T is read +// from the planner's totals[3] and the dispatch is indirect via +// totals[4..6]. + +const S: u32 = {{ s }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var scatter_plan: array; +@group(0) @binding(1) var disjoint_out: array>; +@group(0) @binding(2) var active_sums_new: array>; +@group(0) @binding(3) var totals: array; +@group(0) @binding(4) var consts: vec4; +// consts.x = M_new + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + // Thread t copies the S slots of chunk t; the chunk count T is read + // from totals[3] on the GPU, so threads of the last workgroup with + // t >= T exit. Source element for slot k is the strided index + // t + k * T; the destination index is scatter_plan[t * S + k] and is + // not range-checked against M_new here. + let T = totals[3]; + let M_new = consts.x; + let t = gid.x; + if (t >= T) { return; } + + let out_N = S * T; + let out_plane_x = 0u * PG * out_N; + let out_plane_y = 1u * PG * out_N; + + let new_plane_x = 0u * PG * M_new; + let new_plane_y = 1u * PG * M_new; + + for (var k: u32 = 0u; k < S; k = k + 1u) { + let e = t + k * T; + let dst_idx = scatter_plan[t * S + k]; + + let src_x = out_plane_x + PG * e; + let src_y = out_plane_y + PG * e; + let dst_x = new_plane_x + PG * dst_idx; + let dst_y = new_plane_y + PG * dst_idx; + + active_sums_new[dst_x + 0u] = disjoint_out[src_x + 0u]; + active_sums_new[dst_x + 1u] = disjoint_out[src_x + 1u]; + active_sums_new[dst_y + 0u] = disjoint_out[src_y + 0u]; + active_sums_new[dst_y + 1u] = disjoint_out[src_y + 1u]; + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_tail_reduce_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_tail_reduce_bench.template.wgsl new file mode 100644 index 000000000000..06e0a57e26df ---
/dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_tail_reduce_bench.template.wgsl @@ -0,0 +1,118 @@ +{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +// Tail kernel for the bench-msm-tree pipeline: reduces a single +// tail-sized bucket (count < 2*S) to one sum per thread. Each thread +// reads its bucket's count points sequentially from the SoA-packed +// point pool and accumulates them via direct affine adds (one +// fr_inv_by_a per step). +// +// Pragmatic v1 — no batched inversion across threads. Each step pays +// one full fr_inv_by_a (~80 mont mul equivalents). For typical +// Poisson(lambda=16) MSM workloads, tail buckets carry a minority of +// total work (~10-30%); the contribution to overall bucket-accumulate +// ns/in-pt is small enough that this simple design is acceptable for +// a v1 complete-replacement kernel set. A workgroup-scan +// batched-inversion variant is a follow-on optimisation that would +// drop tail cost to ~25 ns/add (matching the main pair-tree). +// +// Bindings: +// binding 0: csr_indices — sorted point indices, 1-based (index 0 reserved). +// binding 1: tail_plan — three u32 per tail thread: +// [bucket_id, csr_start, count]. +// binding 2: point_pool — SoA-packed pool (2 planes, PG=2 vec4/elem). +// binding 3: bucket_sums — SoA-packed output (2 planes, PG=2 vec4/bucket), +// one packed point per bucket. Pre-zeroed by host. +// binding 4: params — params.x=T (tail thread count), +// params.y=N (pool size), +// params.z=B (bucket_sums slot count). +// +// Bounded loop: the per-thread accumulate loop iterates up to compile- +// time TAIL_CAP = 2*S - 1, breaking early when i >= count. No +// data-dependent unbounded loops. +// +// Helpers: load_pool(plane, idx, N) unpacks pool element idx of the +// given plane (plane base = plane * PG * N vec4); store_bucket(plane, +// b, B, val) packs *val into slot b of the matching bucket_sums plane; +// get_r returns the BigInt populated by the injected r_limbs template. +// +// main: a thread whose count is 0 returns before any store, so its +// bucket_sums slot is left as the host initialised it. Each add +// inverts dx = p_x - acc_x without a zero check, so equal-x operands +// are not handled here. +// NOTE(review): csr_indices values are passed to load_pool unchanged +// (no "- 1"), so if they really are 1-based the pool must reserve +// element 0 — confirm against the host-side packing. 
+ +const TAIL_CAP: u32 = {{ tail_cap }}u; +const PG: u32 = 2u; + +@group(0) @binding(0) var csr_indices: array; +@group(0) @binding(1) var tail_plan: array; +@group(0) @binding(2) var point_pool: array>; +@group(0) @binding(3) var bucket_sums: array>; +@group(0) @binding(4) var params: vec4; + +fn load_pool(plane: u32, idx: u32, N: u32) -> BigInt { + let plane_base = plane * PG * N; + let base = plane_base + PG * idx; + let q0 = point_pool[base + 0u]; + let q1 = point_pool[base + 1u]; + var w: array; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +fn store_bucket(plane: u32, b: u32, B: u32, val: ptr) { + let plane_base = plane * PG * B; + let base = plane_base + PG * b; + let w = pack_limbs_to_256(val); + bucket_sums[base + 0u] = vec4(w[0], w[1], w[2], w[3]); + bucket_sums[base + 1u] = vec4(w[4], w[5], w[6], w[7]); +} + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let T = params.x; + let N = params.y; + let B = params.z; + + let t = gid.x; + if (t >= T) { return; } + + let bucket_id = tail_plan[3u * t + 0u]; + let csr_start = tail_plan[3u * t + 1u]; + let count = tail_plan[3u * t + 2u]; + + if (count == 0u) { return; } + + var acc_x: BigInt = load_pool(0u, csr_indices[csr_start], N); + var acc_y: BigInt = load_pool(1u, csr_indices[csr_start], N); + + for (var i: u32 = 1u; i < TAIL_CAP; i = i + 1u) { + if (i >= count) { break; } + let pt_idx = csr_indices[csr_start + i]; + var p_x: BigInt = load_pool(0u, pt_idx, N); + var p_y: BigInt = load_pool(1u, pt_idx, N); + var dx: BigInt = fr_sub(&p_x, &acc_x); + var inv_dx: BigInt = fr_inv_by_a(dx); + var dy: BigInt = fr_sub(&p_y, &acc_y); + var lambda: BigInt = montgomery_product(&dy, &inv_dx); + var lambda_sq: BigInt = montgomery_product(&lambda, &lambda); + var r_x: BigInt = fr_sub(&lambda_sq, 
&acc_x); + r_x = fr_sub(&r_x, &p_x); + var r_y: BigInt = fr_sub(&acc_x, &r_x); + r_y = montgomery_product(&lambda, &r_y); + r_y = fr_sub(&r_y, &acc_y); + acc_x = r_x; + acc_y = r_y; + } + + store_bucket(0u, bucket_id, B, &acc_x); + store_bucket(1u, bucket_id, B, &acc_y); + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_fused_wg_scan.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_fused_wg_scan.template.wgsl new file mode 100644 index 000000000000..aadbd8c79261 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_fused_wg_scan.template.wgsl @@ -0,0 +1,222 @@ +{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +{{> packed_field_funcs }} + +// Workgroup-scan fused batch-affine round kernel for v2 MSM. +// +// Mirrors `bench_batch_affine.template.wgsl`'s phases A/B/C/D — TPB +// threads cooperating on BATCH_SIZE = TPB*BS pairs per workgroup with +// one fr_inv_by_a per workgroup — adapted for the MSM pipeline: +// - storage is packed 8×u32 per field element (vs the bench's +// BigInt-array storage); conversions happen only at field_load_* +// and field_store, every kernel-local var holds BigInt limbs +// - loads are bucket-indirect via `pair_target_meta` (vs the bench's +// flat `inputs[pair_base + *]`) +// +// PHASES +// A) Per-thread serial prefix product over BS pairs. Each thread +// writes its prefix-product chain to `prefix[batch_base + k]` +// (global storage) and captures `block_total` in a register. +// B) Workgroup-shared Hillis-Steele forward + backward scan over the +// TPB block_totals (log2 TPB rounds of mont mul). +// C) Thread 0 inverts the global product via fr_inv_by_a (ONE per +// workgroup). Broadcasts to wg_inv_total.
+// D) Each thread back-walks its chunk, recovers inv_dx for each pair +// from (wg_inv_total * block_excl_prefix * block_excl_suffix * +// prev_in_chunk_prefix), emits lean affine add, scatters to +// running_x/y[bucket]. +// +// SAFETY +// The scheduler emits at most one pair per (subtask, bucket) per +// round. Within a workgroup's BATCH_SIZE slots, every `bucket` is +// distinct → no intra-workgroup RAW hazard on the running_x/y +// scatters. Across workgroups in the same subtask: disjoint slot +// ranges → still distinct buckets. Across subtasks (Z dim): different +// bucket ranges entirely. +// +// DISPATCH +// workgroup_size = TPB. Workgroups in X = ceil(n / (TPB*BS)). +// Workgroups in Z = num_subtasks. The atomicLoad of count_buf is +// uniform within a workgroup (every thread sees the same `n`), but +// Tint can't prove that — so we never early-return based on it +// before the last workgroupBarrier. Instead, threads whose chunk is +// outside the live pool contribute identity to the scan, skip the +// Phase A loop, and return only at the start of Phase D.
+ +const TPB: u32 = {{ tpb }}u; +const BS: u32 = {{ bs }}u; +const BATCH_SIZE: u32 = {{ batch_size }}u; + +@group(0) @binding(0) +var val_idx: array; +@group(0) @binding(1) +var new_point_x: array>; +@group(0) @binding(2) +var new_point_y: array>; +@group(0) @binding(3) +var running_x: array>; +@group(0) @binding(4) +var running_y: array>; +@group(0) @binding(5) +var pair_target_meta: array; +@group(0) @binding(6) +var prefix_buf: array; +@group(0) @binding(7) +var count_buf: array>; + +// params[0] = num_columns (per-subtask pool stride) +// params[1] = input_size (per-subtask val_idx stride) +@group(0) @binding(8) +var params: vec4; + +var wg_fwd: array; +var wg_bwd: array; +var wg_inv_total: BigInt; + +@compute +@workgroup_size({{ tpb }}) +fn main( + @builtin(local_invocation_id) lid: vec3, + @builtin(workgroup_id) wid: vec3, +) { + let tid = lid.x; + let wg_idx = wid.x; + let subtask_idx = wid.z; + let num_columns = params[0]; + let input_size = params[1]; + + let n = atomicLoad(&count_buf[subtask_idx]); + let batch_base = wg_idx * BATCH_SIZE; + + let pool_base = subtask_idx * num_columns; + let vi_offset = subtask_idx * input_size; + + let chunk_start = tid * BS; + let chunk_pool_base = pool_base + batch_base + chunk_start; + + let in_pool = batch_base + chunk_start + BS <= n; + + // Phase A — per-thread serial prefix product. Threads that are not + // in_pool (chunk past the live pool) contribute identity (R = Mont 1) + // so the workgroup scan reads a sane value at every slot. + // NOTE(review): in_pool requires the whole BS-wide chunk to lie + // below n, so a chunk that straddles n is skipped entirely — + // confirm the scheduler keeps n a multiple of BS. 
+ var block_total: BigInt = get_r(); + if (in_pool) { + { + let k0 = 0u; + let slot = chunk_pool_base + k0; + let bucket = pair_target_meta[2u * slot]; + let q_cursor = pair_target_meta[2u * slot + 1u]; + let pt_idx = val_idx[vi_offset + q_cursor]; + var p_x: BigInt = field_load_rw(bucket, &running_x); + var q_x: BigInt = field_load_ro(pt_idx, &new_point_x); + var dx: BigInt = fr_sub(&q_x, &p_x); + prefix_buf[chunk_pool_base + k0] = dx; + block_total = dx; + } + for (var i: u32 = 1u; i < BS; i = i + 1u) { + let slot = chunk_pool_base + i; + let bucket = pair_target_meta[2u * slot]; + let q_cursor = pair_target_meta[2u * slot + 1u]; + let pt_idx = val_idx[vi_offset + q_cursor]; + var p_x: BigInt = field_load_rw(bucket, &running_x); + var q_x: BigInt = field_load_ro(pt_idx, &new_point_x); + var dx: BigInt = fr_sub(&q_x, &p_x); + block_total = montgomery_product(&block_total, &dx); + prefix_buf[chunk_pool_base + i] = block_total; + } + } + + wg_fwd[tid] = block_total; + wg_bwd[tid] = block_total; + workgroupBarrier(); + + // Phase B — Hillis-Steele forward + backward inclusive scan. + // Each round reads both operands into locals, barriers, then + // writes, so no thread observes a neighbour's already-updated slot + // within the same round. + for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) { + var fwd_x: BigInt = wg_fwd[tid]; + if (tid >= stride) { + var lhs: BigInt = wg_fwd[tid - stride]; + fwd_x = montgomery_product(&lhs, &fwd_x); + } + var bwd_x: BigInt = wg_bwd[tid]; + if (tid + stride < TPB) { + var rhs: BigInt = wg_bwd[tid + stride]; + bwd_x = montgomery_product(&bwd_x, &rhs); + } + workgroupBarrier(); + wg_fwd[tid] = fwd_x; + wg_bwd[tid] = bwd_x; + workgroupBarrier(); + } + + // Phase C — single fr_inv per workgroup. + if (tid == 0u) { + var global_total: BigInt = wg_fwd[TPB - 1u]; + wg_inv_total = fr_inv_by_a(global_total); + } + workgroupBarrier(); + + // Phase D — back-walk this thread's chunk, emit lean affine adds. 
+ if (!in_pool) { + return; + } + var block_excl_prefix: BigInt = get_r(); + if (tid > 0u) { + block_excl_prefix = wg_fwd[tid - 1u]; + } + var block_excl_suffix: BigInt = get_r(); + if (tid + 1u < TPB) { + block_excl_suffix = wg_bwd[tid + 1u]; + } + var inv_global: BigInt = wg_inv_total; + var inv_acc: BigInt = montgomery_product(&inv_global, &block_excl_prefix); + inv_acc = montgomery_product(&inv_acc, &block_excl_suffix); + + for (var off: u32 = 0u; off < BS; off = off + 1u) { + let k = BS - 1u - off; + let slot = chunk_pool_base + k; + let bucket = pair_target_meta[2u * slot]; + let q_cursor = pair_target_meta[2u * slot + 1u]; + let pt_idx = val_idx[vi_offset + q_cursor]; + + var p_x: BigInt = field_load_rw(bucket, &running_x); + var p_y: BigInt = field_load_rw(bucket, &running_y); + var q_x: BigInt = field_load_ro(pt_idx, &new_point_x); + var q_y: BigInt = field_load_ro(pt_idx, &new_point_y); + + var inv_dx: BigInt; + if (k > 0u) { + var prev_prefix: BigInt = prefix_buf[chunk_pool_base + (k - 1u)]; + inv_dx = montgomery_product(&inv_acc, &prev_prefix); + } else { + inv_dx = inv_acc; + } + + var dy: BigInt = fr_sub(&q_y, &p_y); + var lambda: BigInt = montgomery_product(&dy, &inv_dx); + var lambda_sq: BigInt = montgomery_product(&lambda, &lambda); + var t1: BigInt = fr_sub(&lambda_sq, &p_x); + var r_x: BigInt = fr_sub(&t1, &q_x); + var dx_back: BigInt = fr_sub(&p_x, &r_x); + var ldx: BigInt = montgomery_product(&lambda, &dx_back); + var r_y: BigInt = fr_sub(&ldx, &p_y); + + field_store(bucket, &running_x, &r_x); + field_store(bucket, &running_y, &r_y); + + if (k > 0u) { + var dx_k: BigInt = fr_sub(&q_x, &p_x); + inv_acc = montgomery_product(&inv_acc, &dx_k); + } + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/csr_to_v2_active_sums.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/csr_to_v2_active_sums.template.wgsl new file mode 100644 index 000000000000..7abf9766e692 --- /dev/null +++
b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/csr_to_v2_active_sums.template.wgsl @@ -0,0 +1,71 @@ +// Layout converter for the v2 pair-tree MSM bucket-accumulate path. +// +// Materializes the bucket-major active_sums buffer by copying packed +// 8×u32 base coords from the cached_bases (new_point_x / new_point_y) +// at the indices listed in val_idx (cuZK transpose output, bucket-major +// per subtask). +// +// active_sums is one combined-SoA storage buffer (matching what the v2 +// pair-tree kernels marshal_pairs / pair_disjoint_tree / scatter_pairs +// / carry_copy consume): +// plane 0 (x) at vec4 indices [0, PG * M) +// plane 1 (y) at vec4 indices [PG * M, 2 * PG * M) +// per-element layout: PG=2 vec4 at [PG*elem, PG*elem+1]. +// M (elements per plane) is passed via params.y so this shader uses a +// single storage binding instead of two subviews of the same buffer — +// the subview path tripped a silent dispatch no-op on M2 Chrome 148 +// because plane-y's byte offset (PG*M*16 = 8256 for M=258) is not a +// multiple of WebGPU's default minStorageBufferOffsetAlignment of 256. +// +// Per (subtask s, slot k) thread with slot = s * input_size + k: +// pt_idx = val_idx[slot] +// active_sums[PG * slot + v] = new_point_x[PG * pt_idx + v] +// active_sums[PG * M + PG * slot + v] = new_point_y[PG * pt_idx + v] +// for v in {0, 1}. +// +// The copy is a raw element copy — destination element bytes equal +// source element bytes; no unpack / pack needed. Sign handling stays at +// finalize (cuZK encodes signed slices via bucket index, not via point +// negation). 
+ +const PG: u32 = 2u; + +@group(0) @binding(0) +var val_idx: array; +@group(0) @binding(1) +var new_point_x: array>; +@group(0) @binding(2) +var new_point_y: array>; +@group(0) @binding(3) +var active_sums: array>; + +// params.x = total_slots (num_subtasks * input_size, OR per-window +// input_size when the caller binds val_idx as a per-window subview) +// params.y = M (elements per plane in active_sums) +@group(0) @binding(4) +var params: vec4; + +@compute +@workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + // One thread per slot: gathers the point named by val_idx[slot] + // into element `slot` of both active_sums planes. src_x and src_y + // are the same offset because new_point_x and new_point_y are + // indexed by the same pt_idx. Threads with slot >= total exit + // without writing. + let slot = gid.x; + let total = params[0]; + if (slot >= total) { + return; + } + + let M = params[1]; + let pt_idx = val_idx[slot]; + + let plane_x_base = PG * slot; + let plane_y_base = PG * M + PG * slot; + let src_x = PG * pt_idx; + let src_y = PG * pt_idx; + + active_sums[plane_x_base + 0u] = new_point_x[src_x + 0u]; + active_sums[plane_x_base + 1u] = new_point_x[src_x + 1u]; + active_sums[plane_y_base + 0u] = new_point_y[src_y + 0u]; + active_sums[plane_y_base + 1u] = new_point_y[src_y + 1u]; + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/csr_to_v2_meta.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/csr_to_v2_meta.template.wgsl new file mode 100644 index 000000000000..fdc595a0b3bc --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/csr_to_v2_meta.template.wgsl @@ -0,0 +1,45 @@ +// Companion to csr_to_v2_active_sums: derives the per-bucket counts and +// subtask-relative offsets that drive the v2 pair-tree planner. +// +// row_ptr layout: per subtask, num_columns + 1 entries forming a +// CSR-style prefix sum. row_ptr[s * (num_columns + 1) + b + 1] - +// row_ptr[s * (num_columns + 1) + b] is the count of points in bucket +// b of subtask s, and the begin value is the subtask-relative start +// offset within val_idx and active_sums. +// +// One thread per (subtask, bucket) emits one (count, offset) pair.
+ +@group(0) @binding(0) +var row_ptr: array; +@group(0) @binding(1) +var active_counts: array; +@group(0) @binding(2) +var active_offsets: array; + +// params[0] = num_columns +// params[1] = total_buckets (num_subtasks * num_columns) +@group(0) @binding(3) +var params: vec4; + +@compute +@workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + // Splits the flat thread id into (subtask, bucket_local). Each + // subtask owns num_columns + 1 consecutive row_ptr entries, hence + // the "+ 1u" in rp_offset. The count is end - begin; the offset + // written is the raw begin value from row_ptr. + let id = gid.x; + let total = params[1]; + if (id >= total) { + return; + } + + let num_columns = params[0]; + let subtask = id / num_columns; + let bucket_local = id % num_columns; + let rp_offset = subtask * (num_columns + 1u); + + let begin = row_ptr[rp_offset + bucket_local]; + let end = row_ptr[rp_offset + bucket_local + 1u]; + + active_counts[id] = end - begin; + active_offsets[id] = begin; + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/v2_to_running.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/v2_to_running.template.wgsl new file mode 100644 index 000000000000..cadcd5043753 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/v2_to_running.template.wgsl @@ -0,0 +1,62 @@ +// Boundary adapter from the v2 bin-packed pair-tree's per-window +// active_sums buffer (combined SoA, plane 0 = X / plane 1 = Y at vec4 +// indices [PG*elem + v]) to the production running_x / running_y / +// bucket_active layout that batch_affine_finalize_collect consumes. +// +// Per-window dispatch: one thread per (subtask, bucket_local). The +// caller binds the per-window active_sums (combined SoA), the final +// counts and offsets emitted by the planner's last level, and views of +// the global running_x / running_y / bucket_active arrays offset by +// subtask_idx * num_columns so a single bucket_global is addressable +// via gid.x. +// +// For non-empty buckets the v2 pair-tree has reduced the bucket to one +// packed-Montgomery point sitting at active_sums[final_offsets[b]] in +// the input plane layout.
We copy that element into running_x / +// running_y at the matching bucket_global slot (packed 8x u32 = two +// vec4 per element, same layout production already uses when packed). +// Empty buckets only set bucket_active = 0 — running_x / running_y are +// left untouched; finalize is gated on bucket_active and never reads +// the unwritten slot. + +const PG: u32 = 2u; + +@group(0) @binding(0) var active_sums: array>; +@group(0) @binding(1) var final_counts: array; +@group(0) @binding(2) var final_offsets: array; +@group(0) @binding(3) var running_x: array>; +@group(0) @binding(4) var running_y: array>; +@group(0) @binding(5) var bucket_active: array; +@group(0) @binding(6) var params: vec4; +// params.x = num_columns (active per-window bucket count) +// params.y = M (elements per plane in the v2 active_sums buffer) + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + // One thread per bucket_local; threads past num_columns exit. + // Any non-zero count marks the bucket active and copies exactly one + // element (the one at final_offsets[bucket_local]); counts above 1 + // are not distinguished from 1 here. The copy is a raw two-vec4 + // move per coordinate, with no unpack / repack. + let bucket_local = gid.x; + let num_columns = params.x; + let M = params.y; + if (bucket_local >= num_columns) { + return; + } + + let count = final_counts[bucket_local]; + if (count == 0u) { + bucket_active[bucket_local] = 0u; + return; + } + + bucket_active[bucket_local] = 1u; + + let slot = final_offsets[bucket_local]; + let plane_x_base = PG * slot; + let plane_y_base = PG * M + PG * slot; + let dst = PG * bucket_local; + + running_x[dst + 0u] = active_sums[plane_x_base + 0u]; + running_x[dst + 1u] = active_sums[plane_x_base + 1u]; + running_y[dst + 0u] = active_sums[plane_y_base + 0u]; + running_y[dst + 1u] = active_sums[plane_y_base + 1u]; + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/struct/packed_field.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/struct/packed_field.template.wgsl new file mode 100644 index 000000000000..9dec37f1803b --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/struct/packed_field.template.wgsl @@ -0,0 +1,48 @@ +// Packed 256-bit field-element storage helpers for v2 MSM.
+// +// Storage convention: every field-element buffer is `array>` +// with logical stride 2 vec4s per element (8 × u32 = 32 bytes, +// canonical little-endian 256-bit value, value < q < 2^254). +// +// Conversions between the packed storage layout and the 20×13-bit +// `BigInt` arithmetic representation happen ONLY at the storage I/O +// boundary (field_load_*, field_store, fold_packed_pair). Once loaded, +// values live as BigInt limbs for the entire kernel body and only +// repack on the final write. This matches the bench_batch_affine design +// that hit ~22 ns/pair on M2; the prior PackedField-wrapper design +// repacked between every mont and paid ~2× the cost. +// NOTE(review): fold_packed_pair is named above but is not defined in +// this partial — confirm where it lives. +// +// PRECONDITION: this partial must be included after bigint_funcs, +// montgomery_product_funcs, field_funcs, by_inverse_a_funcs, and after +// the host has injected unpack256_to_limbs and pack_limbs_to_256 (those +// come from the decoupledPackUnpackWgsl() generator in shader_manager). + +// Returns the BigInt populated by the injected r_limbs template. +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +// Loads element idx from a read-only buffer: reads vec4s 2*idx and +// 2*idx + 1 and unpacks their 8 u32 words with unpack256_to_limbs. +fn field_load_ro(idx: u32, src: ptr>, read>) -> BigInt { + var w: array; + let q0 = (*src)[2u * idx]; + let q1 = (*src)[2u * idx + 1u]; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +// Same load as field_load_ro, for a read_write buffer. The body is +// duplicated rather than shared; presumably because the pointer's +// access mode is part of its type in WGSL — confirm. +fn field_load_rw(idx: u32, src: ptr>, read_write>) -> BigInt { + var w: array; + let q0 = (*src)[2u * idx]; + let q1 = (*src)[2u * idx + 1u]; + w[0] = q0.x; w[1] = q0.y; w[2] = q0.z; w[3] = q0.w; + w[4] = q1.x; w[5] = q1.y; w[6] = q1.z; w[7] = q1.w; + return unpack256_to_limbs(w); +} + +// Packs *val with pack_limbs_to_256 and writes the 8 words to vec4s +// 2*idx and 2*idx + 1 of dst. +fn field_store(idx: u32, dst: ptr>, read_write>, val: ptr) { + let w = pack_limbs_to_256(val); + (*dst)[2u * idx] = vec4(w[0], w[1], w[2], w[3]); + (*dst)[2u * idx + 1u] = vec4(w[4], w[5], w[6], w[7]); +}