From b0e2d6b03e1ef8f27fa9328530a37838daa66979 Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Thu, 18 Jun 2026 15:57:24 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- backends/webgpu/runtime/ops/sdpa/Sdpa.cpp | 25 +++-- .../ops/sdpa/sdpa_compute_attn_weights.wgsl | 103 +++++++++++++---- .../ops/sdpa/sdpa_compute_attn_weights_wgsl.h | 105 ++++++++++++++---- .../runtime/ops/sdpa/sdpa_compute_out.wgsl | 95 +++++++++++++--- .../runtime/ops/sdpa/sdpa_compute_out_wgsl.h | 97 +++++++++++++--- 5 files changed, 350 insertions(+), 75 deletions(-) diff --git a/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp index dd48f6f5902..fe0684e8b7c 100644 --- a/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp +++ b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp @@ -26,6 +26,13 @@ namespace executorch::backends::webgpu { namespace { +// Register-tile dims; MUST match TM/TN in the reg WGSL kernels. +constexpr int64_t kSdpaTileM = 4; +constexpr int64_t kSdpaTileN = 4; +inline int64_t sdpa_ceil_div(int64_t a, int64_t b) { + return (a + b - 1) / b; +} + // Uniform param structs (all 16-byte aligned, matching the WGSL Params). struct UpdateCacheParams { uint32_t numel; @@ -464,14 +471,16 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { dynamic_pos, "update_cache(V)"); - // --- Dispatch 3: QK -> attn_weights. One thread per (h,s,c) element. + // --- Dispatch 3: QK -> attn_weights. One thread per TM x TN tile. { if (aw_floats > UINT32_MAX) { throw std::runtime_error( "WebGPU sdpa: Hq*S*context_len exceeds uint32 max"); } + const int64_t qk_tiles = Hq * sdpa_ceil_div(S, kSdpaTileM) * + sdpa_ceil_div(context_len, kSdpaTileN); const uint32_t wgc = utils::compute_1d_workgroup_count( - device, static_cast(aw_floats), qk_wg, "QK"); + device, static_cast(qk_tiles), qk_wg, "QK"); AttnWeightsParams p = make_attn_weights_params( S, Hq, Hkv, D, context_len, input_pos, g, scale); WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p)); @@ -515,12 +524,12 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { softmax_buf = ubuf; } - // --- Dispatch 5: AV -> out. One thread per (s,h,d) output element. + // --- Dispatch 5: AV -> out. One thread per TM x TN tile. { - const uint64_t out_floats = static_cast(S) * - static_cast(Hq) * static_cast(D); + const int64_t av_tiles = + Hq * sdpa_ceil_div(S, kSdpaTileM) * sdpa_ceil_div(D, kSdpaTileN); const uint32_t wgc = utils::compute_1d_workgroup_count( - device, static_cast(out_floats), av_wg, "AV"); + device, static_cast(av_tiles), av_wg, "AV"); ComputeOutParams p = make_compute_out_params(S, Hq, Hkv, D, context_len, g); WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p)); BufferBinding bindings[3] = { @@ -591,9 +600,11 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { AttnWeightsParams qp = make_attn_weights_params(S, Hq, Hkv, D, ctx, pos, g, scale); wgpuQueueWriteBuffer(gr.queue(), qk_buf, 0, &qp, sizeof(qp)); + const int64_t qk_tiles = Hq * sdpa_ceil_div(S, kSdpaTileM) * + sdpa_ceil_div(ctx, kSdpaTileN); const uint32_t qk_wgc = utils::compute_1d_workgroup_count( gr.device(), - static_cast(aw_floats), + static_cast(qk_tiles), qk_wg, "QK(resize)"); gr.dispatch_at(qk_idx).workgroup_count_x = qk_wgc; diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl index b9905a59376..f7dae08fb8e 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl @@ -19,37 +19,102 @@ const NEG_INF: f32 = -1.0e30; override wg_size: u32 = 64; +const TM: u32 = 4u; +const TN: u32 = 4u; + +fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4 { + var r = vec4(0.0, 0.0, 0.0, 0.0); + if (s >= params.S) { + return r; + } + let base = s * params.Hq * params.D + h * params.D; + if (d4 + 0u < params.D) { r.x = t_q[base + d4 + 0u]; } + if (d4 + 1u < params.D) { r.y = t_q[base + d4 + 1u]; } + if (d4 + 2u < params.D) { r.z = t_q[base + d4 + 2u]; } + if (d4 + 3u < params.D) { r.w = t_q[base + d4 + 3u]; } + return r; +} + +fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4 { + var r = vec4(0.0, 0.0, 0.0, 0.0); + if (c >= params.context_len) { + return r; + } + let base = c * params.Hkv * params.D + kvh * params.D; + if (d4 + 0u < params.D) { r.x = t_k_cache[base + d4 + 0u]; } + if (d4 + 1u < params.D) { r.y = t_k_cache[base + d4 + 1u]; } + if (d4 + 2u < params.D) { r.z = t_k_cache[base + d4 + 2u]; } + if (d4 + 3u < params.D) { r.w = t_k_cache[base + d4 + 3u]; } + return r; +} + +fn store_qk(s: u32, c: u32, h: u32, raw: f32) { + if (s >= params.S || c >= params.context_len) { + return; + } + var val = raw * params.scale; + // Causal mask: position c may not attend beyond s + input_pos. + if (c > s + params.input_pos) { + val = NEG_INF; + } + let idx = h * params.S * params.context_len + s * params.context_len + c; + t_attn_weights[idx] = val; +} + @compute @workgroup_size(wg_size, 1, 1) fn main(@builtin(global_invocation_id) gid: vec3) { - let total = params.Hq * params.S * params.context_len; - let idx = gid.x; - if (idx >= total) { + let nrt = (params.S + TM - 1u) / TM; + let nct = (params.context_len + TN - 1u) / TN; + let tiles = nrt * nct; + let total = tiles * params.Hq; + if (gid.x >= total) { return; } - let c = idx % params.context_len; - let s = (idx / params.context_len) % params.S; - let h = idx / (params.context_len * params.S); + let h = gid.x / tiles; + let rem = gid.x % tiles; + let row_tile = rem / nct; + let col_tile = rem % nct; let kvh = h / params.g; + let s0 = row_tile * TM; + let c0 = col_tile * TN; - let q_base = s * params.Hq * params.D + h * params.D; - let k_base = c * params.Hkv * params.D + kvh * params.D; + var acc: array, 4>; + acc[0] = vec4(0.0, 0.0, 0.0, 0.0); + acc[1] = vec4(0.0, 0.0, 0.0, 0.0); + acc[2] = vec4(0.0, 0.0, 0.0, 0.0); + acc[3] = vec4(0.0, 0.0, 0.0, 0.0); - var acc: f32 = 0.0; - var d: u32 = 0u; + var d4: u32 = 0u; loop { - if (d >= params.D) { + if (d4 >= params.D) { break; } - acc = acc + t_q[q_base + d] * t_k_cache[k_base + d]; - d = d + 1u; + let q0 = load_q_vec4(s0 + 0u, h, d4); + let q1 = load_q_vec4(s0 + 1u, h, d4); + let q2 = load_q_vec4(s0 + 2u, h, d4); + let q3 = load_q_vec4(s0 + 3u, h, d4); + let k0 = load_k_vec4(c0 + 0u, kvh, d4); + let k1 = load_k_vec4(c0 + 1u, kvh, d4); + let k2 = load_k_vec4(c0 + 2u, kvh, d4); + let k3 = load_k_vec4(c0 + 3u, kvh, d4); + acc[0] += vec4(dot(q0, k0), dot(q0, k1), dot(q0, k2), dot(q0, k3)); + acc[1] += vec4(dot(q1, k0), dot(q1, k1), dot(q1, k2), dot(q1, k3)); + acc[2] += vec4(dot(q2, k0), dot(q2, k1), dot(q2, k2), dot(q2, k3)); + acc[3] += vec4(dot(q3, k0), dot(q3, k1), dot(q3, k2), dot(q3, k3)); + d4 = d4 + 4u; } - acc = acc * params.scale; - // Causal mask: position c may not attend beyond s + input_pos. - if (c > s + params.input_pos) { - acc = NEG_INF; + var m: u32 = 0u; + loop { + if (m >= TM) { + break; + } + let av = acc[m]; + store_qk(s0 + m, c0 + 0u, h, av.x); + store_qk(s0 + m, c0 + 1u, h, av.y); + store_qk(s0 + m, c0 + 2u, h, av.z); + store_qk(s0 + m, c0 + 3u, h, av.w); + m = m + 1u; } - - t_attn_weights[idx] = acc; } diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h index 3f3f3d6b085..e3b703aeed1 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h @@ -13,7 +13,7 @@ namespace executorch::backends::webgpu { // @generated from sdpa_compute_attn_weights.wgsl - DO NOT EDIT. -// wgsl-sha256: 7410869c1c35f09777851bf49b835dc8fecaff3f327aa64a9c900ac0cc3445e1 +// wgsl-sha256: fabbf7f1dcbcac85bb2798ed23c904061d7629afd1abce19c290c81e5f54a47c inline constexpr const char* kSdpaComputeAttnWeightsWGSL = R"( @group(0) @binding(0) var t_attn_weights: array; @group(0) @binding(1) var t_q: array; @@ -36,39 +36,104 @@ const NEG_INF: f32 = -1.0e30; override wg_size: u32 = 64; +const TM: u32 = 4u; +const TN: u32 = 4u; + +fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4 { + var r = vec4(0.0, 0.0, 0.0, 0.0); + if (s >= params.S) { + return r; + } + let base = s * params.Hq * params.D + h * params.D; + if (d4 + 0u < params.D) { r.x = t_q[base + d4 + 0u]; } + if (d4 + 1u < params.D) { r.y = t_q[base + d4 + 1u]; } + if (d4 + 2u < params.D) { r.z = t_q[base + d4 + 2u]; } + if (d4 + 3u < params.D) { r.w = t_q[base + d4 + 3u]; } + return r; +} + +fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4 { + var r = vec4(0.0, 0.0, 0.0, 0.0); + if (c >= params.context_len) { + return r; + } + let base = c * params.Hkv * params.D + kvh * params.D; + if (d4 + 0u < params.D) { r.x = t_k_cache[base + d4 + 0u]; } + if (d4 + 1u < params.D) { r.y = t_k_cache[base + d4 + 1u]; } + if (d4 + 2u < params.D) { r.z = t_k_cache[base + d4 + 2u]; } + if (d4 + 3u < params.D) { r.w = t_k_cache[base + d4 + 3u]; } + return r; +} + +fn store_qk(s: u32, c: u32, h: u32, raw: f32) { + if (s >= params.S || c >= params.context_len) { + return; + } + var val = raw * params.scale; + // Causal mask: position c may not attend beyond s + input_pos. + if (c > s + params.input_pos) { + val = NEG_INF; + } + let idx = h * params.S * params.context_len + s * params.context_len + c; + t_attn_weights[idx] = val; +} + @compute @workgroup_size(wg_size, 1, 1) fn main(@builtin(global_invocation_id) gid: vec3) { - let total = params.Hq * params.S * params.context_len; - let idx = gid.x; - if (idx >= total) { + let nrt = (params.S + TM - 1u) / TM; + let nct = (params.context_len + TN - 1u) / TN; + let tiles = nrt * nct; + let total = tiles * params.Hq; + if (gid.x >= total) { return; } - let c = idx % params.context_len; - let s = (idx / params.context_len) % params.S; - let h = idx / (params.context_len * params.S); + let h = gid.x / tiles; + let rem = gid.x % tiles; + let row_tile = rem / nct; + let col_tile = rem % nct; let kvh = h / params.g; + let s0 = row_tile * TM; + let c0 = col_tile * TN; - let q_base = s * params.Hq * params.D + h * params.D; - let k_base = c * params.Hkv * params.D + kvh * params.D; + var acc: array, 4>; + acc[0] = vec4(0.0, 0.0, 0.0, 0.0); + acc[1] = vec4(0.0, 0.0, 0.0, 0.0); + acc[2] = vec4(0.0, 0.0, 0.0, 0.0); + acc[3] = vec4(0.0, 0.0, 0.0, 0.0); - var acc: f32 = 0.0; - var d: u32 = 0u; + var d4: u32 = 0u; loop { - if (d >= params.D) { + if (d4 >= params.D) { break; } - acc = acc + t_q[q_base + d] * t_k_cache[k_base + d]; - d = d + 1u; + let q0 = load_q_vec4(s0 + 0u, h, d4); + let q1 = load_q_vec4(s0 + 1u, h, d4); + let q2 = load_q_vec4(s0 + 2u, h, d4); + let q3 = load_q_vec4(s0 + 3u, h, d4); + let k0 = load_k_vec4(c0 + 0u, kvh, d4); + let k1 = load_k_vec4(c0 + 1u, kvh, d4); + let k2 = load_k_vec4(c0 + 2u, kvh, d4); + let k3 = load_k_vec4(c0 + 3u, kvh, d4); + acc[0] += vec4(dot(q0, k0), dot(q0, k1), dot(q0, k2), dot(q0, k3)); + acc[1] += vec4(dot(q1, k0), dot(q1, k1), dot(q1, k2), dot(q1, k3)); + acc[2] += vec4(dot(q2, k0), dot(q2, k1), dot(q2, k2), dot(q2, k3)); + acc[3] += vec4(dot(q3, k0), dot(q3, k1), dot(q3, k2), dot(q3, k3)); + d4 = d4 + 4u; } - acc = acc * params.scale; - // Causal mask: position c may not attend beyond s + input_pos. - if (c > s + params.input_pos) { - acc = NEG_INF; + var m: u32 = 0u; + loop { + if (m >= TM) { + break; + } + let av = acc[m]; + store_qk(s0 + m, c0 + 0u, h, av.x); + store_qk(s0 + m, c0 + 1u, h, av.y); + store_qk(s0 + m, c0 + 2u, h, av.z); + store_qk(s0 + m, c0 + 3u, h, av.w); + m = m + 1u; } - - t_attn_weights[idx] = acc; } )"; diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl index 97642670f60..3ac2339376e 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl @@ -16,31 +16,98 @@ struct Params { override wg_size: u32 = 64; +const TM: u32 = 4u; +const TN: u32 = 4u; + +fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { + var r = vec4(0.0, 0.0, 0.0, 0.0); + if (s >= params.S) { + return r; + } + let base = h * params.S * params.context_len + s * params.context_len; + if (c4 + 0u < params.context_len) { r.x = t_attn_weights_softmax[base + c4 + 0u]; } + if (c4 + 1u < params.context_len) { r.y = t_attn_weights_softmax[base + c4 + 1u]; } + if (c4 + 2u < params.context_len) { r.z = t_attn_weights_softmax[base + c4 + 2u]; } + if (c4 + 3u < params.context_len) { r.w = t_attn_weights_softmax[base + c4 + 3u]; } + return r; +} + +fn load_v_vec4(d: u32, kvh: u32, c4: u32) -> vec4 { + var r = vec4(0.0, 0.0, 0.0, 0.0); + if (d >= params.D) { + return r; + } + let stride = params.Hkv * params.D; + let off = kvh * params.D + d; + if (c4 + 0u < params.context_len) { r.x = t_v_cache[(c4 + 0u) * stride + off]; } + if (c4 + 1u < params.context_len) { r.y = t_v_cache[(c4 + 1u) * stride + off]; } + if (c4 + 2u < params.context_len) { r.z = t_v_cache[(c4 + 2u) * stride + off]; } + if (c4 + 3u < params.context_len) { r.w = t_v_cache[(c4 + 3u) * stride + off]; } + return r; +} + +fn store_out(s: u32, d: u32, h: u32, val: f32) { + if (s >= params.S || d >= params.D) { + return; + } + let idx = s * params.Hq * params.D + h * params.D + d; + t_out[idx] = val; +} + @compute @workgroup_size(wg_size, 1, 1) fn main(@builtin(global_invocation_id) gid: vec3) { - let total = params.S * params.Hq * params.D; - let idx = gid.x; - if (idx >= total) { + let nrt = (params.S + TM - 1u) / TM; + let nct = (params.D + TN - 1u) / TN; + let tiles = nrt * nct; + let total = tiles * params.Hq; + if (gid.x >= total) { return; } - let d = idx % params.D; - let h = (idx / params.D) % params.Hq; - let s = idx / (params.D * params.Hq); + let h = gid.x / tiles; + let rem = gid.x % tiles; + let row_tile = rem / nct; + let col_tile = rem % nct; let kvh = h / params.g; + let s0 = row_tile * TM; + let d0 = col_tile * TN; - let aw_base = h * params.S * params.context_len + s * params.context_len; + var acc: array, 4>; + acc[0] = vec4(0.0, 0.0, 0.0, 0.0); + acc[1] = vec4(0.0, 0.0, 0.0, 0.0); + acc[2] = vec4(0.0, 0.0, 0.0, 0.0); + acc[3] = vec4(0.0, 0.0, 0.0, 0.0); - var acc: f32 = 0.0; - var c: u32 = 0u; + var c4: u32 = 0u; loop { - if (c >= params.context_len) { + if (c4 >= params.context_len) { break; } - let v_off = c * params.Hkv * params.D + kvh * params.D + d; - acc = acc + t_attn_weights_softmax[aw_base + c] * t_v_cache[v_off]; - c = c + 1u; + let a0 = load_a_vec4(s0 + 0u, h, c4); + let a1 = load_a_vec4(s0 + 1u, h, c4); + let a2 = load_a_vec4(s0 + 2u, h, c4); + let a3 = load_a_vec4(s0 + 3u, h, c4); + let v0 = load_v_vec4(d0 + 0u, kvh, c4); + let v1 = load_v_vec4(d0 + 1u, kvh, c4); + let v2 = load_v_vec4(d0 + 2u, kvh, c4); + let v3 = load_v_vec4(d0 + 3u, kvh, c4); + acc[0] += vec4(dot(a0, v0), dot(a0, v1), dot(a0, v2), dot(a0, v3)); + acc[1] += vec4(dot(a1, v0), dot(a1, v1), dot(a1, v2), dot(a1, v3)); + acc[2] += vec4(dot(a2, v0), dot(a2, v1), dot(a2, v2), dot(a2, v3)); + acc[3] += vec4(dot(a3, v0), dot(a3, v1), dot(a3, v2), dot(a3, v3)); + c4 = c4 + 4u; } - t_out[idx] = acc; + var m: u32 = 0u; + loop { + if (m >= TM) { + break; + } + let ov = acc[m]; + store_out(s0 + m, d0 + 0u, h, ov.x); + store_out(s0 + m, d0 + 1u, h, ov.y); + store_out(s0 + m, d0 + 2u, h, ov.z); + store_out(s0 + m, d0 + 3u, h, ov.w); + m = m + 1u; + } } diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h index ce25df06876..cf1d742d7e5 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h @@ -13,7 +13,7 @@ namespace executorch::backends::webgpu { // @generated from sdpa_compute_out.wgsl - DO NOT EDIT. -// wgsl-sha256: 67b9c64fbffdcb72264dda42e24b59e414719411c64c504f84f2ba57b5dcfc0f +// wgsl-sha256: 4ffc13bad0bf56b87a57f75307f29e851dd2bd6bf0dba094488df5d262e910e3 inline constexpr const char* kSdpaComputeOutWGSL = R"( @group(0) @binding(0) var t_out: array; @group(0) @binding(1) var t_attn_weights_softmax: array; @@ -33,33 +33,100 @@ struct Params { override wg_size: u32 = 64; +const TM: u32 = 4u; +const TN: u32 = 4u; + +fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { + var r = vec4(0.0, 0.0, 0.0, 0.0); + if (s >= params.S) { + return r; + } + let base = h * params.S * params.context_len + s * params.context_len; + if (c4 + 0u < params.context_len) { r.x = t_attn_weights_softmax[base + c4 + 0u]; } + if (c4 + 1u < params.context_len) { r.y = t_attn_weights_softmax[base + c4 + 1u]; } + if (c4 + 2u < params.context_len) { r.z = t_attn_weights_softmax[base + c4 + 2u]; } + if (c4 + 3u < params.context_len) { r.w = t_attn_weights_softmax[base + c4 + 3u]; } + return r; +} + +fn load_v_vec4(d: u32, kvh: u32, c4: u32) -> vec4 { + var r = vec4(0.0, 0.0, 0.0, 0.0); + if (d >= params.D) { + return r; + } + let stride = params.Hkv * params.D; + let off = kvh * params.D + d; + if (c4 + 0u < params.context_len) { r.x = t_v_cache[(c4 + 0u) * stride + off]; } + if (c4 + 1u < params.context_len) { r.y = t_v_cache[(c4 + 1u) * stride + off]; } + if (c4 + 2u < params.context_len) { r.z = t_v_cache[(c4 + 2u) * stride + off]; } + if (c4 + 3u < params.context_len) { r.w = t_v_cache[(c4 + 3u) * stride + off]; } + return r; +} + +fn store_out(s: u32, d: u32, h: u32, val: f32) { + if (s >= params.S || d >= params.D) { + return; + } + let idx = s * params.Hq * params.D + h * params.D + d; + t_out[idx] = val; +} + @compute @workgroup_size(wg_size, 1, 1) fn main(@builtin(global_invocation_id) gid: vec3) { - let total = params.S * params.Hq * params.D; - let idx = gid.x; - if (idx >= total) { + let nrt = (params.S + TM - 1u) / TM; + let nct = (params.D + TN - 1u) / TN; + let tiles = nrt * nct; + let total = tiles * params.Hq; + if (gid.x >= total) { return; } - let d = idx % params.D; - let h = (idx / params.D) % params.Hq; - let s = idx / (params.D * params.Hq); + let h = gid.x / tiles; + let rem = gid.x % tiles; + let row_tile = rem / nct; + let col_tile = rem % nct; let kvh = h / params.g; + let s0 = row_tile * TM; + let d0 = col_tile * TN; - let aw_base = h * params.S * params.context_len + s * params.context_len; + var acc: array, 4>; + acc[0] = vec4(0.0, 0.0, 0.0, 0.0); + acc[1] = vec4(0.0, 0.0, 0.0, 0.0); + acc[2] = vec4(0.0, 0.0, 0.0, 0.0); + acc[3] = vec4(0.0, 0.0, 0.0, 0.0); - var acc: f32 = 0.0; - var c: u32 = 0u; + var c4: u32 = 0u; loop { - if (c >= params.context_len) { + if (c4 >= params.context_len) { break; } - let v_off = c * params.Hkv * params.D + kvh * params.D + d; - acc = acc + t_attn_weights_softmax[aw_base + c] * t_v_cache[v_off]; - c = c + 1u; + let a0 = load_a_vec4(s0 + 0u, h, c4); + let a1 = load_a_vec4(s0 + 1u, h, c4); + let a2 = load_a_vec4(s0 + 2u, h, c4); + let a3 = load_a_vec4(s0 + 3u, h, c4); + let v0 = load_v_vec4(d0 + 0u, kvh, c4); + let v1 = load_v_vec4(d0 + 1u, kvh, c4); + let v2 = load_v_vec4(d0 + 2u, kvh, c4); + let v3 = load_v_vec4(d0 + 3u, kvh, c4); + acc[0] += vec4(dot(a0, v0), dot(a0, v1), dot(a0, v2), dot(a0, v3)); + acc[1] += vec4(dot(a1, v0), dot(a1, v1), dot(a1, v2), dot(a1, v3)); + acc[2] += vec4(dot(a2, v0), dot(a2, v1), dot(a2, v2), dot(a2, v3)); + acc[3] += vec4(dot(a3, v0), dot(a3, v1), dot(a3, v2), dot(a3, v3)); + c4 = c4 + 4u; } - t_out[idx] = acc; + var m: u32 = 0u; + loop { + if (m >= TM) { + break; + } + let ov = acc[m]; + store_out(s0 + m, d0 + 0u, h, ov.x); + store_out(s0 + m, d0 + 1u, h, ov.y); + store_out(s0 + m, d0 + 2u, h, ov.z); + store_out(s0 + m, d0 + 3u, h, ov.w); + m = m + 1u; + } } )";