From f561681ecf09dd8db648035fd383d4a6ce99bf97 Mon Sep 17 00:00:00 2001 From: Minho Ryu Date: Fri, 1 May 2026 07:39:17 +0900 Subject: [PATCH 1/5] [Common, PyTorch] Add Triton MLA attention kernels for SM80 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit FlashAttention-2 style Triton kernels for MLA-shaped attention (head_dim_qk != head_dim_v, e.g. DeepSeek-V2 192/128) targeted at SM80 (A100), where FlashMLA / FA4-MLA SM80 paths are not available. Three kernel families in transformer_engine/common/triton/mla.py: - Prefill / training forward: standard FA-2 online softmax adapted for non-square head dims, right-aligned causal, autotuned tile sizes. Saves an fp32 LSE for backward. - Analytical backward: canonical FA-2 three-pass structure (preprocess for Delta = rowsum(O * dO), then dQ over Q-tile programs and dK/dV over K/V-tile programs). No atomics — each program owns a distinct output slice. - Decode forward over compressed KV cache: c_kv [B, S_kv, R] and k_rope [B, S_kv, R_rope] with absorbed up-projection (Q's nope side pre-multiplied by W_uk^T). c_kv plays both K (nope side) and V; per-head K/V are never materialized. Returns O_inter [B, H, S_q, R]; the caller applies W_uv. PyTorch wrapper (transformer_engine/pytorch/triton/mla.py): - mla_attention(q, k, v, *, softmax_scale, is_causal, qkv_format) via torch.autograd.Function (Triton fwd + Triton bwd). Layouts: bshd, bhsd, sbhd. Pure-PyTorch mla_attention_ref kept as the test reference. - mla_decode_attention(q_nope_abs, q_rope, c_kv, k_rope, *, softmax_scale, is_causal). softmax_scale is required (R is not the original head_dim_qk so no sane default). Optional DotProductAttention hookup behind NVTE_MLA_TRITON=1 (default off). MLATritonAttention added to backends.py; DotProductAttention.forward gains a strict-precondition early-out that falls through to the regular FA / Fused / Unfused cascade unless every supported feature flag matches (no FP8, no dropout, no context parallel, no alibi/bias, no padding/sliding-window mask, no inference cache, bshd/sbhd, MLA-shaped, SM80+). The existing get_attention_backend() return signature is preserved, so existing dispatch is untouched when the env var is unset. Tests at tests/pytorch/test_mla_triton.py exercise: - Prefill forward across {bf16, fp16} x {causal, non-causal} x {bshd, bhsd, sbhd} for shapes including DeepSeek-V2 prefill, cross-attention (S_q != S_kv), and non-multiple-of-block seqlens. - Backward dQ/dK/dV vs fp32 PyTorch reference within bf16/fp16 tolerances. - Decode forward across DeepSeek-V2 dims (R=512, R_rope=64) and smoke shapes; plus dim-mismatch and dtype rejection. - DPA dispatch: equality with direct mla_attention call when NVTE_MLA_TRITON=1, and fall-through preservation when unset. Signed-off-by: Minho Ryu --- tests/pytorch/test_mla_triton.py | 346 +++++++++ transformer_engine/common/triton/mla.py | 733 ++++++++++++++++++ .../dot_product_attention/backends.py | 41 + .../dot_product_attention.py | 47 ++ transformer_engine/pytorch/triton/__init__.py | 1 + transformer_engine/pytorch/triton/mla.py | 540 +++++++++++++ 6 files changed, 1708 insertions(+) create mode 100644 tests/pytorch/test_mla_triton.py create mode 100644 transformer_engine/common/triton/mla.py create mode 100644 transformer_engine/pytorch/triton/mla.py diff --git a/tests/pytorch/test_mla_triton.py b/tests/pytorch/test_mla_triton.py new file mode 100644 index 0000000000..2aa9c4af8d --- /dev/null +++ b/tests/pytorch/test_mla_triton.py @@ -0,0 +1,346 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tests for the Triton-based MLA kernels (transformer_engine.pytorch.triton.mla).""" + +from dataclasses import dataclass + +import pytest +import torch + +from utils import reset_rng_states +from transformer_engine.pytorch.triton.mla import ( + mla_attention, + mla_attention_ref, + mla_decode_attention, + mla_decode_attention_ref, +) + + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 8, + reason="MLA Triton kernel requires SM80+ (A100 or newer).", +) + +# Disable TF32 to keep the fp32 reference path bit-comparable across runs. +torch.backends.cuda.matmul.allow_tf32 = False + + +@dataclass +class MLAConfig: + b: int + h: int + s_q: int + s_kv: int + d_qk: int + d_v: int + + @staticmethod + def desc(cfg): + return ( + f"b{cfg.b}_h{cfg.h}_sq{cfg.s_q}_skv{cfg.s_kv}" + f"_dqk{cfg.d_qk}_dv{cfg.d_v}" + ) + + +mla_configs = [ + # square head dim sanity + MLAConfig(2, 4, 128, 128, 64, 64), + MLAConfig(2, 4, 256, 256, 128, 128), + # DeepSeek-V2 prefill shape + MLAConfig(2, 8, 512, 512, 192, 128), + # bigger seq + MLAConfig(1, 16, 2048, 2048, 192, 128), + # cross-attention-style: S_q != S_kv + MLAConfig(2, 8, 512, 1024, 192, 128), + # non-multiple-of-block seq lengths + MLAConfig(2, 4, 513, 513, 192, 128), + # decode-shaped Q (still using the prefill kernel) + MLAConfig(1, 4, 1, 1024, 192, 128), +] + + +def _tols(dtype): + if dtype == torch.bfloat16: + return dict(atol=2.5e-2, rtol=2.5e-2) + return dict(atol=5e-3, rtol=5e-3) + + +def _make_qkv(cfg, dtype, qkv_format): + """Allocate (q, k, v) in the requested layout. Sized so values stay well within the + kernel's dynamic range (avoids saturating exp in fp32 accumulator).""" + scale = 0.5 + if qkv_format == "bhsd": + q = torch.randn(cfg.b, cfg.h, cfg.s_q, cfg.d_qk, device="cuda", dtype=dtype) * scale + k = torch.randn(cfg.b, cfg.h, cfg.s_kv, cfg.d_qk, device="cuda", dtype=dtype) * scale + v = torch.randn(cfg.b, cfg.h, cfg.s_kv, cfg.d_v, device="cuda", dtype=dtype) * scale + elif qkv_format == "bshd": + q = torch.randn(cfg.b, cfg.s_q, cfg.h, cfg.d_qk, device="cuda", dtype=dtype) * scale + k = torch.randn(cfg.b, cfg.s_kv, cfg.h, cfg.d_qk, device="cuda", dtype=dtype) * scale + v = torch.randn(cfg.b, cfg.s_kv, cfg.h, cfg.d_v, device="cuda", dtype=dtype) * scale + elif qkv_format == "sbhd": + q = torch.randn(cfg.s_q, cfg.b, cfg.h, cfg.d_qk, device="cuda", dtype=dtype) * scale + k = torch.randn(cfg.s_kv, cfg.b, cfg.h, cfg.d_qk, device="cuda", dtype=dtype) * scale + v = torch.randn(cfg.s_kv, cfg.b, cfg.h, cfg.d_v, device="cuda", dtype=dtype) * scale + else: + raise ValueError(qkv_format) + return q, k, v + + +def _ref_in_user_layout(q, k, v, qkv_format, *, softmax_scale=None, is_causal=False): + """Run mla_attention_ref on tensors that are in user layout. + + The reference operates in BHSD; we transpose, run, and transpose back. + """ + if qkv_format == "bshd": + q_b = q.transpose(1, 2).contiguous() + k_b = k.transpose(1, 2).contiguous() + v_b = v.transpose(1, 2).contiguous() + elif qkv_format == "sbhd": + q_b = q.permute(1, 2, 0, 3).contiguous() + k_b = k.permute(1, 2, 0, 3).contiguous() + v_b = v.permute(1, 2, 0, 3).contiguous() + else: + q_b, k_b, v_b = q, k, v + out_bhsd = mla_attention_ref(q_b, k_b, v_b, softmax_scale=softmax_scale, is_causal=is_causal) + if qkv_format == "bshd": + return out_bhsd.transpose(1, 2).contiguous() + if qkv_format == "sbhd": + return out_bhsd.permute(2, 0, 1, 3).contiguous() + return out_bhsd + + +@pytest.mark.parametrize("cfg", mla_configs, ids=MLAConfig.desc) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"]) +@pytest.mark.parametrize("is_causal", [False, True], ids=["nocausal", "causal"]) +@pytest.mark.parametrize("qkv_format", ["bshd", "bhsd", "sbhd"]) +def test_mla_forward(cfg, dtype, is_causal, qkv_format): + reset_rng_states() + q, k, v = _make_qkv(cfg, dtype, qkv_format) + + out_triton = mla_attention(q, k, v, is_causal=is_causal, qkv_format=qkv_format) + out_ref = _ref_in_user_layout(q, k, v, qkv_format, is_causal=is_causal) + + assert out_triton.shape == out_ref.shape + assert out_triton.dtype == out_ref.dtype + torch.testing.assert_close(out_triton, out_ref, **_tols(dtype)) + + +_backward_configs = [ + MLAConfig(2, 4, 128, 128, 64, 64), + MLAConfig(2, 4, 256, 256, 128, 128), + MLAConfig(2, 8, 256, 256, 192, 128), # DeepSeek-V2 dims, smaller seq + MLAConfig(2, 8, 512, 512, 192, 128), # DeepSeek-V2 prefill + MLAConfig(2, 4, 256, 512, 192, 128), # cross-attn shape + MLAConfig(2, 4, 257, 257, 192, 128), # non-multiple-of-block seq +] + + +@pytest.mark.parametrize("cfg", _backward_configs, ids=MLAConfig.desc) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"]) +@pytest.mark.parametrize("is_causal", [False, True], ids=["nocausal", "causal"]) +def test_mla_backward_matches_reference(cfg, dtype, is_causal): + """Triton-computed dQ/dK/dV must match the pure-PyTorch reference within bf16/fp16 tolerances.""" + reset_rng_states() + q_base, k_base, v_base = _make_qkv(cfg, dtype, qkv_format="bshd") + q = q_base.detach().clone().requires_grad_(True) + k = k_base.detach().clone().requires_grad_(True) + v = v_base.detach().clone().requires_grad_(True) + q_ref = q_base.detach().clone().requires_grad_(True) + k_ref = k_base.detach().clone().requires_grad_(True) + v_ref = v_base.detach().clone().requires_grad_(True) + + out_triton = mla_attention(q, k, v, is_causal=is_causal, qkv_format="bshd") + out_ref = _ref_in_user_layout(q_ref, k_ref, v_ref, "bshd", is_causal=is_causal) + + # Smaller-magnitude grad_o so dS = P*(dP - Delta) stays well within the + # accumulator's representable range for the non-square head-dim cases. + grad_o = torch.randn_like(out_triton) * 0.1 + out_triton.backward(grad_o) + out_ref.backward(grad_o) + + tols = _tols(dtype) + torch.testing.assert_close(q.grad, q_ref.grad, **tols) + torch.testing.assert_close(k.grad, k_ref.grad, **tols) + torch.testing.assert_close(v.grad, v_ref.grad, **tols) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"]) +@pytest.mark.parametrize("is_causal", [False, True], ids=["nocausal", "causal"]) +def test_mla_layout_equivalence(dtype, is_causal): + """Same logical inputs in BSHD and BHSD must produce equal outputs (modulo layout).""" + reset_rng_states() + cfg = MLAConfig(2, 4, 256, 256, 192, 128) + q_b, k_b, v_b = _make_qkv(cfg, dtype, qkv_format="bhsd") + q_s = q_b.transpose(1, 2).contiguous() + k_s = k_b.transpose(1, 2).contiguous() + v_s = v_b.transpose(1, 2).contiguous() + + out_bhsd = mla_attention(q_b, k_b, v_b, is_causal=is_causal, qkv_format="bhsd") + out_bshd = mla_attention(q_s, k_s, v_s, is_causal=is_causal, qkv_format="bshd") + + # Triton kernel runs on the canonicalized BHSD path in both cases, so outputs + # must match exactly after layout permutation. + torch.testing.assert_close(out_bhsd, out_bshd.transpose(1, 2).contiguous(), atol=0.0, rtol=0.0) + + +def test_mla_softmax_scale_default(): + """Verify the default softmax_scale is 1/sqrt(head_dim_qk) (Q-side, not V-side).""" + reset_rng_states() + cfg = MLAConfig(1, 2, 64, 64, 192, 128) + q, k, v = _make_qkv(cfg, torch.bfloat16, "bshd") + out_default = mla_attention(q, k, v, qkv_format="bshd") + out_explicit = mla_attention( + q, k, v, softmax_scale=cfg.d_qk**-0.5, qkv_format="bshd" + ) + torch.testing.assert_close(out_default, out_explicit, atol=0.0, rtol=0.0) + + +def test_mla_rejects_fp32_input(): + """The Triton kernel is fp16/bf16 only — fp32 inputs should raise.""" + cfg = MLAConfig(1, 2, 64, 64, 64, 64) + q, k, v = _make_qkv(cfg, torch.float32, "bshd") + with pytest.raises(ValueError, match="fp16 or bf16"): + mla_attention(q, k, v, qkv_format="bshd") + + +# --------------------------------------------------------------------------- +# Decode (absorbed up-projection) tests +# --------------------------------------------------------------------------- + + +@dataclass +class MLADecodeConfig: + b: int + h: int + s_q: int + s_kv: int + r: int # kv_lora_rank + r_rope: int + + @staticmethod + def desc(cfg): + return ( + f"b{cfg.b}_h{cfg.h}_sq{cfg.s_q}_skv{cfg.s_kv}" + f"_r{cfg.r}_rrope{cfg.r_rope}" + ) + + +_decode_configs = [ + MLADecodeConfig(1, 4, 1, 128, 64, 16), # smoke + MLADecodeConfig(1, 4, 1, 512, 64, 16), # bigger Skv + MLADecodeConfig(1, 4, 4, 512, 64, 16), # multi-token / speculative decode + MLADecodeConfig(2, 4, 128, 128, 128, 32), # prefill via decode kernel + # DeepSeek-V2 dims (kv_lora_rank=512, rope_dim=64). + MLADecodeConfig(1, 8, 1, 1024, 512, 64), + MLADecodeConfig(1, 16, 1, 2048, 512, 64), + MLADecodeConfig(1, 4, 1, 257, 512, 64), # non-multiple-of-block S_kv +] + + +def _make_decode_inputs(cfg, dtype, scale=0.5): + qn = torch.randn(cfg.b, cfg.h, cfg.s_q, cfg.r, device="cuda", dtype=dtype) * scale + qr = torch.randn(cfg.b, cfg.h, cfg.s_q, cfg.r_rope, device="cuda", dtype=dtype) * scale + ck = torch.randn(cfg.b, cfg.s_kv, cfg.r, device="cuda", dtype=dtype) * scale + kr = torch.randn(cfg.b, cfg.s_kv, cfg.r_rope, device="cuda", dtype=dtype) * scale + return qn, qr, ck, kr + + +@pytest.mark.parametrize("cfg", _decode_configs, ids=MLADecodeConfig.desc) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"]) +@pytest.mark.parametrize("is_causal", [False, True], ids=["nocausal", "causal"]) +def test_mla_decode_forward(cfg, dtype, is_causal): + reset_rng_states() + qn, qr, ck, kr = _make_decode_inputs(cfg, dtype) + softmax_scale = 1.0 / (cfg.r + cfg.r_rope) ** 0.5 + + o = mla_decode_attention(qn, qr, ck, kr, softmax_scale=softmax_scale, is_causal=is_causal) + o_ref = mla_decode_attention_ref(qn, qr, ck, kr, softmax_scale=softmax_scale, is_causal=is_causal) + + assert o.shape == o_ref.shape == (cfg.b, cfg.h, cfg.s_q, cfg.r) + assert o.dtype == o_ref.dtype == dtype + torch.testing.assert_close(o, o_ref, **_tols(dtype)) + + +def test_mla_decode_rejects_fp32_input(): + cfg = MLADecodeConfig(1, 2, 1, 64, 64, 16) + qn, qr, ck, kr = _make_decode_inputs(cfg, torch.float32) + with pytest.raises(ValueError, match="fp16 or bf16"): + mla_decode_attention(qn, qr, ck, kr, softmax_scale=0.1) + + +def test_mla_decode_rejects_dim_mismatch(): + cfg = MLADecodeConfig(1, 2, 1, 64, 64, 16) + qn, qr, ck, kr = _make_decode_inputs(cfg, torch.bfloat16) + # Mismatched kv_lora_rank between Q-side and cache. + bad_ck = torch.randn(1, 64, 32, device="cuda", dtype=torch.bfloat16) + with pytest.raises(ValueError, match="kv_lora_rank"): + mla_decode_attention(qn, qr, bad_ck, kr, softmax_scale=0.1) + + +# --------------------------------------------------------------------------- +# DotProductAttention dispatch (NVTE_MLA_TRITON=1 fast path) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"]) +@pytest.mark.parametrize("is_causal", [False, True], ids=["nocausal", "causal"]) +@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd"]) +def test_dpa_dispatches_to_mla_triton(monkeypatch, dtype, is_causal, qkv_format): + """With NVTE_MLA_TRITON=1, an MLA-shaped DotProductAttention call must + produce the same output as a direct ``mla_attention`` call.""" + import transformer_engine.pytorch as te + + reset_rng_states() + cfg = MLAConfig(2, 4, 256, 256, 192, 128) + q, k, v = _make_qkv(cfg, dtype, qkv_format=qkv_format) + + monkeypatch.setenv("NVTE_MLA_TRITON", "1") + + # Force backend cache invalidation so the env-var check re-runs. + from transformer_engine.pytorch.attention.dot_product_attention import dot_product_attention as dpa_mod + dpa_mod._attention_backends["backend_selection_requires_update"] = True + + softmax_scale = cfg.d_qk ** -0.5 + dpa = te.DotProductAttention( + num_attention_heads=cfg.h, + kv_channels=(cfg.d_qk, cfg.d_v), + attention_dropout=0.0, + softmax_scale=softmax_scale, + qkv_format=qkv_format, + attn_mask_type="causal" if is_causal else "no_mask", + ).cuda() + + out_dpa = dpa(q, k, v) + out_direct = mla_attention( + q, k, v, softmax_scale=softmax_scale, is_causal=is_causal, qkv_format=qkv_format + ) + torch.testing.assert_close(out_dpa, out_direct, atol=0.0, rtol=0.0) + + +def test_dpa_falls_through_when_env_var_unset(monkeypatch): + """With NVTE_MLA_TRITON unset, dispatch must NOT route through the MLA + Triton backend even for MLA-shaped inputs (existing behavior preserved).""" + import transformer_engine.pytorch as te + from transformer_engine.pytorch.attention.dot_product_attention import dot_product_attention as dpa_mod + + reset_rng_states() + cfg = MLAConfig(2, 4, 256, 256, 192, 128) + q, k, v = _make_qkv(cfg, torch.bfloat16, qkv_format="bshd") + + monkeypatch.delenv("NVTE_MLA_TRITON", raising=False) + dpa_mod._attention_backends["backend_selection_requires_update"] = True + + dpa = te.DotProductAttention( + num_attention_heads=cfg.h, + kv_channels=(cfg.d_qk, cfg.d_v), + attention_dropout=0.0, + softmax_scale=cfg.d_qk ** -0.5, + qkv_format="bshd", + attn_mask_type="causal", + ).cuda() + + # Just exercise it — any cuDNN/FA backend may be selected. We only assert + # no exception (i.e. the early-out wasn't accidentally triggered). + _ = dpa(q, k, v) diff --git a/transformer_engine/common/triton/mla.py b/transformer_engine/common/triton/mla.py new file mode 100644 index 0000000000..f4ec2a9ed0 --- /dev/null +++ b/transformer_engine/common/triton/mla.py @@ -0,0 +1,733 @@ +# pylint: disable=missing-function-docstring + +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Triton kernels for MLA (Multi-head Latent Attention). + +Two operating modes: + +1. **Prefill / training** (``_mla_attn_fwd`` + backward kernels): + FlashAttention-2 style with non-square head dimensions + (``head_dim_qk != head_dim_v`` as used by DeepSeek-V2/V3). Inputs are full + Q, K, V tensors. Backward uses the canonical FA-2 three-pass structure + (preprocess for ``Delta = rowsum(O * dO)``, then ``dQ`` and ``dK``/``dV`` + kernels). No atomics — each program owns a distinct output slice. + +2. **Absorbed-projection decode** (``_mla_decode_attn_fwd``): + Operates on a compressed KV cache (latent ``c_kv`` of dim ``kv_lora_rank`` + plus a decoupled rope key ``k_rope``). Q's nope side is pre-absorbed via + ``W_uk^T`` so the kernel never materializes per-head K or V; ``c_kv`` + serves as both the key (nope side) and the value. Output is the + pre-``W_uv`` intermediate ``O_inter`` of shape ``[B, H, S_q, kv_lora_rank]``; + the caller applies ``W_uv`` to produce the final attention output. + +Tuned for SM80 (A100). Compiles on SM89/SM90 too but is not specialized for them. +""" + +import itertools +import os + +import triton +import triton.language as tl + + +def _mla_fwd_configs(): + block_m = [64, 128] + block_n = [32, 64] + num_warps = [4, 8] + num_stages = [2, 3] + + configs = [] + for m, n, w, s in itertools.product(block_m, block_n, num_warps, num_stages): + configs.append( + triton.Config( + {"BLOCK_M": m, "BLOCK_N": n}, + num_warps=w, + num_stages=s, + ) + ) + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs + + +@triton.autotune( + configs=_mla_fwd_configs(), + key=["S_q", "S_kv", "D_qk", "D_v", "IS_CAUSAL"], +) +@triton.jit +def _mla_attn_fwd( + Q_ptr, # (B, H, S_q, D_qk) + K_ptr, # (B, H, S_kv, D_qk) + V_ptr, # (B, H, S_kv, D_v) + O_ptr, # (B, H, S_q, D_v) + LSE_ptr, # (B, H, S_q) fp32 + softmax_scale, + B, + H, + S_q, + S_kv, + D_qk, + D_v, + sQ_b, + sQ_h, + sQ_s, + sQ_d: tl.constexpr, + sK_b, + sK_h, + sK_s, + sK_d: tl.constexpr, + sV_b, + sV_h, + sV_s, + sV_d: tl.constexpr, + sO_b, + sO_h, + sO_s, + sO_d: tl.constexpr, + sL_b, + sL_h, + sL_s: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL_QK: tl.constexpr, + BLOCK_DMODEL_V: tl.constexpr, +): + """One program -> one BLOCK_M tile of Q rows for one (batch, head). + + Layout: BHSD. Right-aligned causal: row i attends to keys j s.t. + ``j <= i + (S_kv - S_q)``. + """ + pid_m = tl.program_id(0) + pid_bh = tl.program_id(1) + tl.assume(pid_m >= 0) + tl.assume(pid_bh >= 0) + + off_b = pid_bh // H + off_h = pid_bh % H + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n_init = tl.arange(0, BLOCK_N) + offs_d_qk = tl.arange(0, BLOCK_DMODEL_QK) + offs_d_v = tl.arange(0, BLOCK_DMODEL_V) + + mask_m = offs_m < S_q + mask_d_qk = offs_d_qk < D_qk + mask_d_v = offs_d_v < D_v + + # Load Q tile once (resident across the inner K/V loop). + q_base = Q_ptr + off_b * sQ_b + off_h * sQ_h + q_ptrs = q_base + offs_m[:, None] * sQ_s + offs_d_qk[None, :] * sQ_d + q = tl.load(q_ptrs, mask=mask_m[:, None] & mask_d_qk[None, :], other=0.0) + + k_base = K_ptr + off_b * sK_b + off_h * sK_h + v_base = V_ptr + off_b * sV_b + off_h * sV_h + + # ``m_i`` is initialized to a finite "very negative" sentinel rather than -inf + # so that exp(m_i - m_ij) is well-defined even when an entire K block is masked + # (every qk == -inf, e.g. the rare S_q > S_kv + causal case). -1e6 is large + # enough that exp(-1e6 - any_realistic_qk) underflows to 0 in fp32. + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - 1.0e6 + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_V], dtype=tl.float32) + + # Loop bound: standard FA-2 early exit for causal. No clamp to 0 needed — + # ``range`` with a non-positive upper bound is empty. + if IS_CAUSAL: + hi = tl.minimum((pid_m + 1) * BLOCK_M + (S_kv - S_q), S_kv) + else: + hi = S_kv + + for start_n in range(0, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + offs_n = start_n + offs_n_init + mask_n = offs_n < S_kv + + # Load K tile [BLOCK_N, BLOCK_DMODEL_QK]. + k_ptrs = k_base + offs_n[:, None] * sK_s + offs_d_qk[None, :] * sK_d + k = tl.load(k_ptrs, mask=mask_n[:, None] & mask_d_qk[None, :], other=0.0) + + # qk = Q @ K^T -> [BLOCK_M, BLOCK_N] (fp32 accum) + qk = tl.dot(q, tl.trans(k)) + qk = qk * softmax_scale + + # Mask invalid keys (K-tail past S_kv) to -inf BEFORE softmax max/exp. + qk = tl.where(mask_n[None, :], qk, float("-inf")) + if IS_CAUSAL: + causal_mask = offs_n[None, :] <= (offs_m[:, None] + (S_kv - S_q)) + qk = tl.where(causal_mask, qk, float("-inf")) + + # Online softmax. ``m_i`` is initialized to a finite sentinel so that + # exp(m_i - m_ij) is well-defined even when this entire block is masked + # (all qk == -inf). + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.exp(m_i - m_ij) + p = tl.exp(qk - m_ij[:, None]) + l_i = l_i * alpha + tl.sum(p, 1) + acc = acc * alpha[:, None] + + # Load V tile [BLOCK_N, BLOCK_DMODEL_V] and accumulate P @ V. + v_ptrs = v_base + offs_n[:, None] * sV_s + offs_d_v[None, :] * sV_d + v = tl.load(v_ptrs, mask=mask_n[:, None] & mask_d_v[None, :], other=0.0) + acc = tl.dot(p.to(v.dtype), v, acc) + + m_i = m_ij + + # Epilogue: normalize and store O in input dtype. + acc = acc / l_i[:, None] + o_base = O_ptr + off_b * sO_b + off_h * sO_h + o_ptrs = o_base + offs_m[:, None] * sO_s + offs_d_v[None, :] * sO_d + tl.store( + o_ptrs, + acc.to(O_ptr.dtype.element_ty), + mask=mask_m[:, None] & mask_d_v[None, :], + ) + + # Store fp32 LSE = m_i + log(l_i) (used by the analytical backward). + lse_base = LSE_ptr + off_b * sL_b + off_h * sL_h + lse_ptrs = lse_base + offs_m * sL_s + tl.store(lse_ptrs, m_i + tl.log(l_i), mask=mask_m) + + +# --------------------------------------------------------------------------- +# Backward kernels (FA-2 style, three passes, no atomics) +# --------------------------------------------------------------------------- + + +def _mla_bwd_configs(): + block_m = [64, 128] + block_n = [32, 64] + num_warps = [4, 8] + num_stages = [2, 3] + configs = [] + for m, n, w, s in itertools.product(block_m, block_n, num_warps, num_stages): + configs.append( + triton.Config( + {"BLOCK_M": m, "BLOCK_N": n}, + num_warps=w, + num_stages=s, + ) + ) + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs + + +@triton.jit +def _mla_attn_bwd_preprocess( + O_ptr, # (B, H, S_q, D_v) + DO_ptr, # (B, H, S_q, D_v) + Delta_ptr, # (B, H, S_q) fp32 output + B, + H, + S_q, + D_v, + sO_b, + sO_h, + sO_s, + sO_d: tl.constexpr, + sDO_b, + sDO_h, + sDO_s, + sDO_d: tl.constexpr, + sD_b, + sD_h, + sD_s: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL_V: tl.constexpr, +): + """Compute ``Delta[b,h,m] = sum_d O[b,h,m,d] * dO[b,h,m,d]`` in fp32.""" + pid_m = tl.program_id(0) + pid_bh = tl.program_id(1) + tl.assume(pid_m >= 0) + tl.assume(pid_bh >= 0) + + off_b = pid_bh // H + off_h = pid_bh % H + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_DMODEL_V) + mask_m = offs_m < S_q + mask_d = offs_d < D_v + mask_md = mask_m[:, None] & mask_d[None, :] + + o_ptrs = ( + O_ptr + + off_b * sO_b + + off_h * sO_h + + offs_m[:, None] * sO_s + + offs_d[None, :] * sO_d + ) + do_ptrs = ( + DO_ptr + + off_b * sDO_b + + off_h * sDO_h + + offs_m[:, None] * sDO_s + + offs_d[None, :] * sDO_d + ) + o = tl.load(o_ptrs, mask=mask_md, other=0.0).to(tl.float32) + do = tl.load(do_ptrs, mask=mask_md, other=0.0).to(tl.float32) + + delta = tl.sum(o * do, axis=1) + delta_ptrs = Delta_ptr + off_b * sD_b + off_h * sD_h + offs_m * sD_s + tl.store(delta_ptrs, delta, mask=mask_m) + + +@triton.autotune( + configs=_mla_bwd_configs(), + key=["S_q", "S_kv", "D_qk", "D_v", "IS_CAUSAL"], +) +@triton.jit +def _mla_attn_bwd_dq( + Q_ptr, # (B, H, S_q, D_qk) + K_ptr, # (B, H, S_kv, D_qk) + V_ptr, # (B, H, S_kv, D_v) + DO_ptr, # (B, H, S_q, D_v) + LSE_ptr, # (B, H, S_q) fp32 + Delta_ptr, # (B, H, S_q) fp32 + DQ_ptr, # (B, H, S_q, D_qk) output + softmax_scale, + B, + H, + S_q, + S_kv, + D_qk, + D_v, + sQ_b, + sQ_h, + sQ_s, + sQ_d: tl.constexpr, + sK_b, + sK_h, + sK_s, + sK_d: tl.constexpr, + sV_b, + sV_h, + sV_s, + sV_d: tl.constexpr, + sDO_b, + sDO_h, + sDO_s, + sDO_d: tl.constexpr, + sLSE_b, + sLSE_h, + sLSE_s: tl.constexpr, + sD_b, + sD_h, + sD_s: tl.constexpr, + sDQ_b, + sDQ_h, + sDQ_s, + sDQ_d: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL_QK: tl.constexpr, + BLOCK_DMODEL_V: tl.constexpr, +): + """Compute ``dQ`` for one BLOCK_M row tile of one (batch, head).""" + pid_m = tl.program_id(0) + pid_bh = tl.program_id(1) + tl.assume(pid_m >= 0) + tl.assume(pid_bh >= 0) + + off_b = pid_bh // H + off_h = pid_bh % H + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n_init = tl.arange(0, BLOCK_N) + offs_d_qk = tl.arange(0, BLOCK_DMODEL_QK) + offs_d_v = tl.arange(0, BLOCK_DMODEL_V) + + mask_m = offs_m < S_q + mask_d_qk = offs_d_qk < D_qk + mask_d_v = offs_d_v < D_v + + # Load Q, dO, LSE, Delta — resident across the K/V loop. + q_base = Q_ptr + off_b * sQ_b + off_h * sQ_h + q_ptrs = q_base + offs_m[:, None] * sQ_s + offs_d_qk[None, :] * sQ_d + q = tl.load(q_ptrs, mask=mask_m[:, None] & mask_d_qk[None, :], other=0.0) + + do_base = DO_ptr + off_b * sDO_b + off_h * sDO_h + do_ptrs = do_base + offs_m[:, None] * sDO_s + offs_d_v[None, :] * sDO_d + do = tl.load(do_ptrs, mask=mask_m[:, None] & mask_d_v[None, :], other=0.0) + + # ``other=+inf`` so that exp(qk - lse) underflows to 0 for invalid Q rows + # and they contribute nothing to dQ. + lse_base = LSE_ptr + off_b * sLSE_b + off_h * sLSE_h + lse = tl.load(lse_base + offs_m * sLSE_s, mask=mask_m, other=float("inf")) + + delta_base = Delta_ptr + off_b * sD_b + off_h * sD_h + delta = tl.load(delta_base + offs_m * sD_s, mask=mask_m, other=0.0) + + dq = tl.zeros([BLOCK_M, BLOCK_DMODEL_QK], dtype=tl.float32) + + if IS_CAUSAL: + n_hi = tl.minimum((pid_m + 1) * BLOCK_M + (S_kv - S_q), S_kv) + else: + n_hi = S_kv + + k_base = K_ptr + off_b * sK_b + off_h * sK_h + v_base = V_ptr + off_b * sV_b + off_h * sV_h + + for start_n in range(0, n_hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + offs_n = start_n + offs_n_init + mask_n = offs_n < S_kv + + k_ptrs = k_base + offs_n[:, None] * sK_s + offs_d_qk[None, :] * sK_d + k = tl.load(k_ptrs, mask=mask_n[:, None] & mask_d_qk[None, :], other=0.0) + v_ptrs = v_base + offs_n[:, None] * sV_s + offs_d_v[None, :] * sV_d + v = tl.load(v_ptrs, mask=mask_n[:, None] & mask_d_v[None, :], other=0.0) + + # Recompute qk and apply same masks as the forward. + qk = tl.dot(q, tl.trans(k)) + qk = qk * softmax_scale + qk = tl.where(mask_n[None, :], qk, float("-inf")) + if IS_CAUSAL: + causal_mask = offs_n[None, :] <= (offs_m[:, None] + (S_kv - S_q)) + qk = tl.where(causal_mask, qk, float("-inf")) + + p = tl.exp(qk - lse[:, None]) + # dP = dO @ V^T -> [BLOCK_M, BLOCK_N] + dp = tl.dot(do, tl.trans(v)) + # dS = P * (dP - Delta) * scale + ds = (p * (dp - delta[:, None])) * softmax_scale + # dQ += dS @ K + dq += tl.dot(ds.to(k.dtype), k) + + dq_base = DQ_ptr + off_b * sDQ_b + off_h * sDQ_h + dq_ptrs = dq_base + offs_m[:, None] * sDQ_s + offs_d_qk[None, :] * sDQ_d + tl.store( + dq_ptrs, + dq.to(DQ_ptr.dtype.element_ty), + mask=mask_m[:, None] & mask_d_qk[None, :], + ) + + +@triton.autotune( + configs=_mla_bwd_configs(), + key=["S_q", "S_kv", "D_qk", "D_v", "IS_CAUSAL"], +) +@triton.jit +def _mla_attn_bwd_dkv( + Q_ptr, + K_ptr, + V_ptr, + DO_ptr, + LSE_ptr, + Delta_ptr, + DK_ptr, # (B, H, S_kv, D_qk) output + DV_ptr, # (B, H, S_kv, D_v) output + softmax_scale, + B, + H, + S_q, + S_kv, + D_qk, + D_v, + sQ_b, + sQ_h, + sQ_s, + sQ_d: tl.constexpr, + sK_b, + sK_h, + sK_s, + sK_d: tl.constexpr, + sV_b, + sV_h, + sV_s, + sV_d: tl.constexpr, + sDO_b, + sDO_h, + sDO_s, + sDO_d: tl.constexpr, + sLSE_b, + sLSE_h, + sLSE_s: tl.constexpr, + sD_b, + sD_h, + sD_s: tl.constexpr, + sDK_b, + sDK_h, + sDK_s, + sDK_d: tl.constexpr, + sDV_b, + sDV_h, + sDV_s, + sDV_d: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL_QK: tl.constexpr, + BLOCK_DMODEL_V: tl.constexpr, +): + """Compute ``dK`` and ``dV`` for one BLOCK_N tile of one (batch, head).""" + pid_n = tl.program_id(0) + pid_bh = tl.program_id(1) + tl.assume(pid_n >= 0) + tl.assume(pid_bh >= 0) + + off_b = pid_bh // H + off_h = pid_bh % H + + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m_init = tl.arange(0, BLOCK_M) + offs_d_qk = tl.arange(0, BLOCK_DMODEL_QK) + offs_d_v = tl.arange(0, BLOCK_DMODEL_V) + + mask_n = offs_n < S_kv + mask_d_qk = offs_d_qk < D_qk + mask_d_v = offs_d_v < D_v + + # Load K, V tiles — resident across the Q loop. + k_base = K_ptr + off_b * sK_b + off_h * sK_h + k_ptrs = k_base + offs_n[:, None] * sK_s + offs_d_qk[None, :] * sK_d + k = tl.load(k_ptrs, mask=mask_n[:, None] & mask_d_qk[None, :], other=0.0) + + v_base = V_ptr + off_b * sV_b + off_h * sV_h + v_ptrs = v_base + offs_n[:, None] * sV_s + offs_d_v[None, :] * sV_d + v = tl.load(v_ptrs, mask=mask_n[:, None] & mask_d_v[None, :], other=0.0) + + dk = tl.zeros([BLOCK_N, BLOCK_DMODEL_QK], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_DMODEL_V], dtype=tl.float32) + + # Causal: only Q rows i with j_max(i) >= n_start contribute. Round the + # lower bound *down* to a multiple of BLOCK_M so loads stay aligned; the + # extra rows iterated are masked out per-element. + if IS_CAUSAL: + m_lo = pid_n * BLOCK_N - (S_kv - S_q) + m_lo = tl.maximum(m_lo, 0) + m_lo = (m_lo // BLOCK_M) * BLOCK_M + else: + m_lo = 0 + + for start_m in range(m_lo, S_q, BLOCK_M): + start_m = tl.multiple_of(start_m, BLOCK_M) + offs_m = start_m + offs_m_init + mask_m = offs_m < S_q + + q_base = Q_ptr + off_b * sQ_b + off_h * sQ_h + q_ptrs = q_base + offs_m[:, None] * sQ_s + offs_d_qk[None, :] * sQ_d + q = tl.load(q_ptrs, mask=mask_m[:, None] & mask_d_qk[None, :], other=0.0) + + # Recompute qk with the forward's mask convention. + qk = tl.dot(q, tl.trans(k)) + qk = qk * softmax_scale + qk = tl.where(mask_n[None, :], qk, float("-inf")) + if IS_CAUSAL: + causal_mask = offs_n[None, :] <= (offs_m[:, None] + (S_kv - S_q)) + qk = tl.where(causal_mask, qk, float("-inf")) + + # Invalid Q rows: load LSE with ``+inf`` so P underflows to 0 and these + # rows contribute nothing to dK/dV. + lse_base = LSE_ptr + off_b * sLSE_b + off_h * sLSE_h + lse = tl.load(lse_base + offs_m * sLSE_s, mask=mask_m, other=float("inf")) + delta_base = Delta_ptr + off_b * sD_b + off_h * sD_h + delta = tl.load(delta_base + offs_m * sD_s, mask=mask_m, other=0.0) + + do_base = DO_ptr + off_b * sDO_b + off_h * sDO_h + do_ptrs = do_base + offs_m[:, None] * sDO_s + offs_d_v[None, :] * sDO_d + do = tl.load(do_ptrs, mask=mask_m[:, None] & mask_d_v[None, :], other=0.0) + + p = tl.exp(qk - lse[:, None]) # [BLOCK_M, BLOCK_N] + + # dV += P^T @ dO + dv += tl.dot(tl.trans(p).to(do.dtype), do) + + # dP = dO @ V^T -> [BLOCK_M, BLOCK_N] + dp = tl.dot(do, tl.trans(v)) + ds = (p * (dp - delta[:, None])) * softmax_scale + # dK += dS^T @ Q + dk += tl.dot(tl.trans(ds).to(q.dtype), q) + + dk_base = DK_ptr + off_b * sDK_b + off_h * sDK_h + dk_ptrs = dk_base + offs_n[:, None] * sDK_s + offs_d_qk[None, :] * sDK_d + tl.store( + dk_ptrs, + dk.to(DK_ptr.dtype.element_ty), + mask=mask_n[:, None] & mask_d_qk[None, :], + ) + + dv_base = DV_ptr + off_b * sDV_b + off_h * sDV_h + dv_ptrs = dv_base + offs_n[:, None] * sDV_s + offs_d_v[None, :] * sDV_d + tl.store( + dv_ptrs, + dv.to(DV_ptr.dtype.element_ty), + mask=mask_n[:, None] & mask_d_v[None, :], + ) + + +# --------------------------------------------------------------------------- +# Decode forward kernel (absorbed up-projection over compressed KV cache) +# --------------------------------------------------------------------------- + + +def _mla_decode_fwd_configs(): + # Decode tiles are smaller than prefill because the effective K dim is + # ``kv_lora_rank`` (e.g. 512) which is much larger than a normal head dim, + # so SMEM pressure is high. Triton's autotune will silently prune configs + # that exceed the SM80 SMEM budget. + block_m = [16, 32, 64] + block_n = [16, 32, 64] + num_warps = [4, 8] + num_stages = [2, 3] + configs = [] + for m, n, w, s in itertools.product(block_m, block_n, num_warps, num_stages): + configs.append( + triton.Config( + {"BLOCK_M": m, "BLOCK_N": n}, + num_warps=w, + num_stages=s, + ) + ) + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs + + +@triton.autotune( + configs=_mla_decode_fwd_configs(), + key=["S_q", "S_kv", "R", "R_rope", "IS_CAUSAL"], +) +@triton.jit +def _mla_decode_attn_fwd( + QN_ptr, # (B, H, S_q, R) Q_nope_abs (Q_nope already multiplied by W_uk^T) + QR_ptr, # (B, H, S_q, R_rope) Q_rope + CKV_ptr, # (B, S_kv, R) compressed KV cache (shared across heads) + KR_ptr, # (B, S_kv, R_rope) decoupled rope key (shared across heads) + O_ptr, # (B, H, S_q, R) O_inter (caller applies W_uv) + LSE_ptr, # (B, H, S_q) fp32 LSE + softmax_scale, + B, + H, + S_q, + S_kv, + R, + R_rope, + # Q_nope_abs strides + sQN_b, + sQN_h, + sQN_s, + sQN_r: tl.constexpr, + # Q_rope strides + sQR_b, + sQR_h, + sQR_s, + sQR_r: tl.constexpr, + # c_kv strides (no H) + sCKV_b, + sCKV_s, + sCKV_r: tl.constexpr, + # k_rope strides (no H) + sKR_b, + sKR_s, + sKR_r: tl.constexpr, + # O_inter strides + sO_b, + sO_h, + sO_s, + sO_r: tl.constexpr, + # LSE strides + sLSE_b, + sLSE_h, + sLSE_s: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL_R: tl.constexpr, # next_pow2(R) + BLOCK_DMODEL_RR: tl.constexpr, # next_pow2(R_rope) +): + """Absorbed MLA decode forward — one BLOCK_M tile of Q rows for one (b, h). + + Computes ``score = Q_nope_abs @ c_kv^T + Q_rope @ k_rope^T`` (scaled), + softmax, then ``O_inter = P @ c_kv``. ``c_kv`` is reused as both K (for the + nope-side score) and V; per-head ``K_nope`` / ``V`` are never materialized. + """ + pid_m = tl.program_id(0) + pid_bh = tl.program_id(1) + tl.assume(pid_m >= 0) + tl.assume(pid_bh >= 0) + + off_b = pid_bh // H + off_h = pid_bh % H + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n_init = tl.arange(0, BLOCK_N) + offs_r = tl.arange(0, BLOCK_DMODEL_R) + offs_rr = tl.arange(0, BLOCK_DMODEL_RR) + + mask_m = offs_m < S_q + mask_r = offs_r < R + mask_rr = offs_rr < R_rope + + # Load Q_nope_abs and Q_rope tiles once. + qn_base = QN_ptr + off_b * sQN_b + off_h * sQN_h + qn_ptrs = qn_base + offs_m[:, None] * sQN_s + offs_r[None, :] * sQN_r + qn = tl.load(qn_ptrs, mask=mask_m[:, None] & mask_r[None, :], other=0.0) + + qr_base = QR_ptr + off_b * sQR_b + off_h * sQR_h + qr_ptrs = qr_base + offs_m[:, None] * sQR_s + offs_rr[None, :] * sQR_r + qr = tl.load(qr_ptrs, mask=mask_m[:, None] & mask_rr[None, :], other=0.0) + + ckv_base = CKV_ptr + off_b * sCKV_b + kr_base = KR_ptr + off_b * sKR_b + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - 1.0e6 + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_R], dtype=tl.float32) + + if IS_CAUSAL: + hi = tl.minimum((pid_m + 1) * BLOCK_M + (S_kv - S_q), S_kv) + else: + hi = S_kv + + for start_n in range(0, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + offs_n = start_n + offs_n_init + mask_n = offs_n < S_kv + + # c_kv tile [BLOCK_N, BLOCK_DMODEL_R] — used for both the nope-side + # score and as the value tile. + ckv_ptrs = ckv_base + offs_n[:, None] * sCKV_s + offs_r[None, :] * sCKV_r + ckv = tl.load(ckv_ptrs, mask=mask_n[:, None] & mask_r[None, :], other=0.0) + + # k_rope tile [BLOCK_N, BLOCK_DMODEL_RR] + kr_ptrs = kr_base + offs_n[:, None] * sKR_s + offs_rr[None, :] * sKR_r + kr = tl.load(kr_ptrs, mask=mask_n[:, None] & mask_rr[None, :], other=0.0) + + # Scores: nope side is Q_nope_abs @ c_kv^T, rope side is Q_rope @ k_rope^T. + score_nope = tl.dot(qn, tl.trans(ckv)) + score_rope = tl.dot(qr, tl.trans(kr)) + qk = (score_nope + score_rope) * softmax_scale + + qk = tl.where(mask_n[None, :], qk, float("-inf")) + if IS_CAUSAL: + causal_mask = offs_n[None, :] <= (offs_m[:, None] + (S_kv - S_q)) + qk = tl.where(causal_mask, qk, float("-inf")) + + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.exp(m_i - m_ij) + p = tl.exp(qk - m_ij[:, None]) + l_i = l_i * alpha + tl.sum(p, 1) + acc = acc * alpha[:, None] + # Reuse the c_kv tile as V: O_inter += P @ c_kv. + acc = tl.dot(p.to(ckv.dtype), ckv, acc) + + m_i = m_ij + + acc = acc / l_i[:, None] + o_base = O_ptr + off_b * sO_b + off_h * sO_h + o_ptrs = o_base + offs_m[:, None] * sO_s + offs_r[None, :] * sO_r + tl.store( + o_ptrs, + acc.to(O_ptr.dtype.element_ty), + mask=mask_m[:, None] & mask_r[None, :], + ) + + lse_base = LSE_ptr + off_b * sLSE_b + off_h * sLSE_h + lse_ptrs = lse_base + offs_m * sLSE_s + tl.store(lse_ptrs, m_i + tl.log(l_i), mask=mask_m) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 4104820a1c..2bb3ede65c 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -2175,3 +2175,44 @@ def forward( return output[0].view(*output[0].shape[:-2], -1), output[1] # ...hd -> ...(hd) return output.view(*output.shape[:-2], -1) + + +class MLATritonAttention(torch.nn.Module): + """Triton MLA backend wrapping :func:`transformer_engine.pytorch.triton.mla.mla_attention`. + + Opt-in path for SM80 MLA-shaped attention (``head_dim_qk != head_dim_v``). + Activated by setting ``NVTE_MLA_TRITON=1``; see + :class:`DotProductAttention` for the dispatch conditions. Inputs are + expected to already be in ``[B, S, H, D]`` (``bshd``) or + ``[S, B, H, D]`` (``sbhd``) layout — ``thd`` and ``no_mask``-padding + variants fall through to the regular cascade. + """ + + def __init__(self, softmax_scale: float, attention_dropout: float = 0.0) -> None: + super().__init__() + self.softmax_scale = softmax_scale + if attention_dropout != 0.0: + raise ValueError( + "MLATritonAttention does not support attention_dropout > 0." + ) + + def forward( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + qkv_format: str, + is_causal: bool, + ) -> torch.Tensor: + # Imported lazily so that environments without triton.mla don't pay + # the import cost when this backend is never selected. + from transformer_engine.pytorch.triton.mla import mla_attention + + return mla_attention( + query_layer, + key_layer, + value_layer, + softmax_scale=self.softmax_scale, + is_causal=is_causal, + qkv_format=qkv_format, + ) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 17e9a337a4..8044fc9816 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -60,6 +60,7 @@ UnfusedDotProductAttention, FusedAttention, FlashAttention, + MLATritonAttention, ) @@ -505,6 +506,15 @@ def __init__( return_max_logit=self.return_max_logit, ) + # Triton MLA backend for SM80 (A100). Opt-in via NVTE_MLA_TRITON=1; see + # :meth:`forward` for the dispatch conditions. Always instantiated so + # the env var can be flipped at runtime; the underlying triton kernel + # is JIT-compiled lazily on first invocation. + self.mla_triton_attention = MLATritonAttention( + softmax_scale, + attention_dropout=attention_dropout, + ) if attention_dropout == 0.0 else None + def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument """ Temporarily remove core_attention._extra_state as a missing key @@ -1500,6 +1510,43 @@ def forward( " disabling all backends." ) + # Optional MLA Triton fast path. Opt-in via NVTE_MLA_TRITON=1 and + # only activated when every supported precondition holds. Any + # mismatch (FP8, dropout, CP, sliding window, alibi, padding mask, + # cross-attn with top-left causal, etc.) silently falls through to + # the regular FA/Fused/Unfused cascade — defaulting to no behavior + # change when the env var is unset. + if ( + int(os.getenv("NVTE_MLA_TRITON", "0")) + and self.mla_triton_attention is not None + and head_dim_qk != head_dim_v + and head_dim_qk in (128, 192, 256) and head_dim_v in (64, 128) + and query_layer.dtype in (torch.bfloat16, torch.float16) + and not self.fp8 + and not context_parallel + and core_attention_bias_type == "no_bias" + and alibi_slopes is None + and inference_params is None + and self.softmax_type == "vanilla" + and qkv_format in ("bshd", "sbhd") + and attn_mask_type in (None, "no_mask", "causal", "causal_bottom_right") + and window_size in (None, (-1, -1), (-1, 0)) + and torch.cuda.is_available() + and torch.cuda.get_device_capability(query_layer.device)[0] >= 8 + ): + # TE's "causal" is bottom-right when S_q != S_kv via + # bottom_right_diagonal=True (default). Our kernel right-aligns + # the diagonal by construction, so both map to is_causal=True. + is_causal_dispatch = attn_mask_type in ("causal", "causal_bottom_right") + self.logger.info("Running with MLATriton backend (NVTE_MLA_TRITON=1)") + return self.mla_triton_attention( + query_layer, + key_layer, + value_layer, + qkv_format=qkv_format, + is_causal=is_causal_dispatch, + ) + # run attention softmax_offset = ( self.softmax_offset.reshape(1, -1, 1, 1).to(torch.float32) diff --git a/transformer_engine/pytorch/triton/__init__.py b/transformer_engine/pytorch/triton/__init__.py index 6d3141253d..8ec7af220b 100644 --- a/transformer_engine/pytorch/triton/__init__.py +++ b/transformer_engine/pytorch/triton/__init__.py @@ -4,3 +4,4 @@ """PyTorch wrappers for Triton kernels.""" from transformer_engine.pytorch.triton import mhc +from transformer_engine.pytorch.triton import mla diff --git a/transformer_engine/pytorch/triton/mla.py b/transformer_engine/pytorch/triton/mla.py new file mode 100644 index 0000000000..20b32dec44 --- /dev/null +++ b/transformer_engine/pytorch/triton/mla.py @@ -0,0 +1,540 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""PyTorch wrapper for the Triton MLA attention kernels. + +Public entrypoint: :func:`mla_attention`. Forward and backward both run on +Triton kernels (FA-2 style). The pure-PyTorch :func:`mla_attention_ref` is +kept as the test reference. + +The kernels do NOT use ``tl.atomic_add`` (forward owns one O slice per program; +backward uses two passes — dQ owned by Q-tile programs, dK/dV owned by K-tile +programs). Results are deterministic, so no ``NVTE_ALLOW_NONDETERMINISTIC_ALGO`` +gate is needed. +""" + +from typing import Optional + +import torch +import triton + +from transformer_engine.common.triton.mla import ( + _mla_attn_fwd, + _mla_attn_bwd_preprocess, + _mla_attn_bwd_dq, + _mla_attn_bwd_dkv, + _mla_decode_attn_fwd, +) + + +_SUPPORTED_QKV_FORMATS = ("bshd", "bhsd", "sbhd") + + +def _user_to_bhsd(t: torch.Tensor, qkv_format: str) -> torch.Tensor: + """User-layout tensor -> contiguous BHSD.""" + if qkv_format == "bhsd": + return t.contiguous() if not t.is_contiguous() else t + if qkv_format == "bshd": + # [B, S, H, D] -> [B, H, S, D] + return t.transpose(1, 2).contiguous() + if qkv_format == "sbhd": + # [S, B, H, D] -> [B, H, S, D] + return t.permute(1, 2, 0, 3).contiguous() + raise ValueError( + f"Unsupported qkv_format: {qkv_format!r}. Expected one of {_SUPPORTED_QKV_FORMATS}." + ) + + +def _bhsd_to_user(t: torch.Tensor, qkv_format: str) -> torch.Tensor: + """BHSD tensor -> contiguous user layout.""" + if qkv_format == "bhsd": + return t + if qkv_format == "bshd": + # [B, H, S, D] -> [B, S, H, D] + return t.transpose(1, 2).contiguous() + if qkv_format == "sbhd": + # [B, H, S, D] -> [S, B, H, D] + return t.permute(2, 0, 1, 3).contiguous() + raise ValueError( + f"Unsupported qkv_format: {qkv_format!r}. Expected one of {_SUPPORTED_QKV_FORMATS}." + ) + + +def mla_attention_ref( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale: Optional[float] = None, + is_causal: bool = False, +) -> torch.Tensor: + """Pure-PyTorch reference MLA attention in BHSD layout. + + Supports ``head_dim_qk != head_dim_v``. Internal compute is fp32; output is + cast back to the input dtype. Right-aligned causal mask matches the kernel: + row ``i`` attends to keys ``j <= i + (S_kv - S_q)``. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + in_dtype = q.dtype + q32 = q.float() + k32 = k.float() + v32 = v.float() + + s = torch.matmul(q32, k32.transpose(-1, -2)) * softmax_scale + if is_causal: + s_q, s_kv = s.shape[-2], s.shape[-1] + row = torch.arange(s_q, device=s.device).unsqueeze(-1) + col = torch.arange(s_kv, device=s.device).unsqueeze(0) + mask = col > (row + (s_kv - s_q)) + s = s.masked_fill(mask, float("-inf")) + p = torch.softmax(s, dim=-1) + out = torch.matmul(p, v32) + return out.to(in_dtype) + + +def _launch_mla_fwd(q_bhsd, k_bhsd, v_bhsd, softmax_scale, is_causal): + B, H, S_q, D_qk = q_bhsd.shape + S_kv = k_bhsd.shape[2] + D_v = v_bhsd.shape[3] + + o = torch.empty((B, H, S_q, D_v), device=q_bhsd.device, dtype=q_bhsd.dtype) + lse = torch.empty((B, H, S_q), device=q_bhsd.device, dtype=torch.float32) + + block_dmodel_qk = max(16, triton.next_power_of_2(D_qk)) + block_dmodel_v = max(16, triton.next_power_of_2(D_v)) + + grid = lambda meta: (triton.cdiv(S_q, meta["BLOCK_M"]), B * H) + + _mla_attn_fwd[grid]( + q_bhsd, + k_bhsd, + v_bhsd, + o, + lse, + float(softmax_scale), + B, + H, + S_q, + S_kv, + D_qk, + D_v, + # Q strides + q_bhsd.stride(0), + q_bhsd.stride(1), + q_bhsd.stride(2), + 1, + # K strides + k_bhsd.stride(0), + k_bhsd.stride(1), + k_bhsd.stride(2), + 1, + # V strides + v_bhsd.stride(0), + v_bhsd.stride(1), + v_bhsd.stride(2), + 1, + # O strides + o.stride(0), + o.stride(1), + o.stride(2), + 1, + # LSE strides + lse.stride(0), + lse.stride(1), + 1, + IS_CAUSAL=is_causal, + BLOCK_DMODEL_QK=block_dmodel_qk, + BLOCK_DMODEL_V=block_dmodel_v, + ) + return o, lse + + +# Preprocess kernel uses a fixed BLOCK_M (it's I/O-bound and the choice barely +# matters); the dQ / dKV kernels autotune and pick their own BLOCK_M / BLOCK_N. +_BWD_PREPROCESS_BLOCK_M = 128 + + +def _launch_mla_bwd(q_bhsd, k_bhsd, v_bhsd, o_bhsd, lse, do_bhsd, softmax_scale, is_causal): + """Run the three backward kernels and return ``(dQ, dK, dV)`` in BHSD.""" + B, H, S_q, D_qk = q_bhsd.shape + S_kv = k_bhsd.shape[2] + D_v = v_bhsd.shape[3] + + block_dmodel_qk = max(16, triton.next_power_of_2(D_qk)) + block_dmodel_v = max(16, triton.next_power_of_2(D_v)) + + delta = torch.empty((B, H, S_q), device=q_bhsd.device, dtype=torch.float32) + dq = torch.empty_like(q_bhsd) + dk = torch.empty_like(k_bhsd) + dv = torch.empty_like(v_bhsd) + + # 1) Delta = rowsum(O * dO) + grid_pre = (triton.cdiv(S_q, _BWD_PREPROCESS_BLOCK_M), B * H) + _mla_attn_bwd_preprocess[grid_pre]( + o_bhsd, + do_bhsd, + delta, + B, + H, + S_q, + D_v, + # O strides + o_bhsd.stride(0), + o_bhsd.stride(1), + o_bhsd.stride(2), + 1, + # dO strides + do_bhsd.stride(0), + do_bhsd.stride(1), + do_bhsd.stride(2), + 1, + # Delta strides + delta.stride(0), + delta.stride(1), + 1, + BLOCK_M=_BWD_PREPROCESS_BLOCK_M, + BLOCK_DMODEL_V=block_dmodel_v, + ) + + common_args = ( + q_bhsd, + k_bhsd, + v_bhsd, + do_bhsd, + lse, + delta, + ) + common_shape_args = (B, H, S_q, S_kv, D_qk, D_v) + common_strides = ( + # Q + q_bhsd.stride(0), q_bhsd.stride(1), q_bhsd.stride(2), 1, + # K + k_bhsd.stride(0), k_bhsd.stride(1), k_bhsd.stride(2), 1, + # V + v_bhsd.stride(0), v_bhsd.stride(1), v_bhsd.stride(2), 1, + # dO + do_bhsd.stride(0), do_bhsd.stride(1), do_bhsd.stride(2), 1, + # LSE + lse.stride(0), lse.stride(1), 1, + # Delta + delta.stride(0), delta.stride(1), 1, + ) + + # 2) dQ + grid_dq = lambda meta: (triton.cdiv(S_q, meta["BLOCK_M"]), B * H) + _mla_attn_bwd_dq[grid_dq]( + *common_args, + dq, + float(softmax_scale), + *common_shape_args, + *common_strides, + # dQ strides + dq.stride(0), dq.stride(1), dq.stride(2), 1, + IS_CAUSAL=is_causal, + BLOCK_DMODEL_QK=block_dmodel_qk, + BLOCK_DMODEL_V=block_dmodel_v, + ) + + # 3) dK, dV + grid_dkv = lambda meta: (triton.cdiv(S_kv, meta["BLOCK_N"]), B * H) + _mla_attn_bwd_dkv[grid_dkv]( + *common_args, + dk, + dv, + float(softmax_scale), + *common_shape_args, + *common_strides, + # dK strides + dk.stride(0), dk.stride(1), dk.stride(2), 1, + # dV strides + dv.stride(0), dv.stride(1), dv.stride(2), 1, + IS_CAUSAL=is_causal, + BLOCK_DMODEL_QK=block_dmodel_qk, + BLOCK_DMODEL_V=block_dmodel_v, + ) + + return dq, dk, dv + + +class MLAttentionFn(torch.autograd.Function): + """Forward and backward via Triton kernels (FA-2 style).""" + + @staticmethod + def forward(ctx, q, k, v, softmax_scale, is_causal, qkv_format): + if qkv_format not in _SUPPORTED_QKV_FORMATS: + raise ValueError( + f"qkv_format must be one of {_SUPPORTED_QKV_FORMATS}, got {qkv_format!r}" + ) + if q.dtype not in (torch.float16, torch.bfloat16): + raise ValueError( + f"mla_attention requires fp16 or bf16 inputs, got {q.dtype}" + ) + if not (q.dtype == k.dtype == v.dtype): + raise ValueError("q, k, v must share the same dtype") + if not q.is_cuda: + raise ValueError("mla_attention requires CUDA tensors") + + q_bhsd = _user_to_bhsd(q, qkv_format) + k_bhsd = _user_to_bhsd(k, qkv_format) + v_bhsd = _user_to_bhsd(v, qkv_format) + + if q_bhsd.shape[3] != k_bhsd.shape[3]: + raise ValueError( + "q.head_dim and k.head_dim must match" + f" (got {q_bhsd.shape[3]} vs {k_bhsd.shape[3]})" + ) + if q_bhsd.shape[:2] != k_bhsd.shape[:2] or q_bhsd.shape[:2] != v_bhsd.shape[:2]: + raise ValueError("q, k, v must share batch and head dimensions") + if k_bhsd.shape[2] != v_bhsd.shape[2]: + raise ValueError("k.seq_len and v.seq_len must match") + + o_bhsd, lse = _launch_mla_fwd(q_bhsd, k_bhsd, v_bhsd, softmax_scale, is_causal) + + ctx.save_for_backward(q_bhsd, k_bhsd, v_bhsd, o_bhsd, lse) + ctx.softmax_scale = softmax_scale + ctx.is_causal = is_causal + ctx.qkv_format = qkv_format + + return _bhsd_to_user(o_bhsd, qkv_format) + + @staticmethod + def backward(ctx, grad_o): + q_bhsd, k_bhsd, v_bhsd, o_bhsd, lse = ctx.saved_tensors + grad_o_bhsd = _user_to_bhsd(grad_o, ctx.qkv_format) + + dq_bhsd, dk_bhsd, dv_bhsd = _launch_mla_bwd( + q_bhsd, + k_bhsd, + v_bhsd, + o_bhsd, + lse, + grad_o_bhsd, + ctx.softmax_scale, + ctx.is_causal, + ) + + grad_q = _bhsd_to_user(dq_bhsd, ctx.qkv_format) + grad_k = _bhsd_to_user(dk_bhsd, ctx.qkv_format) + grad_v = _bhsd_to_user(dv_bhsd, ctx.qkv_format) + return grad_q, grad_k, grad_v, None, None, None + + +def mla_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + softmax_scale: Optional[float] = None, + is_causal: bool = False, + qkv_format: str = "bshd", +) -> torch.Tensor: + """Triton MLA attention (FA-2 style forward and backward) for SM80+. + + Supports ``head_dim_qk != head_dim_v`` (DeepSeek-V2/V3 style). Both forward + and backward run as Triton kernels: the forward saves an fp32 ``LSE``, which + the backward consumes to recompute softmax probabilities without extra + memory. + + Parameters + ---------- + q, k, v + Tensors in ``qkv_format`` layout. ``q`` and ``k`` must share the last + dim (head_dim_qk); ``v`` may have a different last dim (head_dim_v). + Dtype must be fp16 or bf16. Batch and head dims must match across the + three tensors; ``k`` and ``v`` must share seq_len. + softmax_scale + Multiplier applied to ``Q @ K^T`` before softmax. Defaults to + ``1 / sqrt(head_dim_qk)``. + is_causal + If True, applies a right-aligned causal mask: row ``i`` attends to + keys ``j <= i + (S_kv - S_q)``. With ``S_q == S_kv`` this is the + standard causal self-attention mask. + qkv_format + ``"bshd"`` (default, matches FlashAttention), ``"bhsd"`` (matches + ``F.scaled_dot_product_attention``), or ``"sbhd"`` (megatron-style). + + Returns + ------- + torch.Tensor + Attention output in ``qkv_format`` layout, dtype matching the inputs. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + return MLAttentionFn.apply(q, k, v, softmax_scale, is_causal, qkv_format) + + +# --------------------------------------------------------------------------- +# Decode (absorbed up-projection over compressed KV cache) +# --------------------------------------------------------------------------- + + +def mla_decode_attention_ref( + q_nope_abs: torch.Tensor, + q_rope: torch.Tensor, + c_kv: torch.Tensor, + k_rope: torch.Tensor, + softmax_scale: float, + is_causal: bool = False, +) -> torch.Tensor: + """Pure-PyTorch reference for absorbed MLA decode attention. + + Inputs (BHSD layout for Q-side, BSR layout for compressed cache): + + - ``q_nope_abs`` ``[B, H, S_q, R]`` — Q's nope side already multiplied by + ``W_uk^T``. ``R`` is ``kv_lora_rank``. + - ``q_rope`` ``[B, H, S_q, R_rope]`` — Q's rope side. + - ``c_kv`` ``[B, S_kv, R]`` — compressed KV cache, shared across heads. + - ``k_rope`` ``[B, S_kv, R_rope]`` — decoupled rope keys, shared across heads. + - ``softmax_scale`` — typically ``1 / sqrt(head_dim_qk_orig)`` where + ``head_dim_qk_orig = q_nope_dim + R_rope`` (the *un-absorbed* head dim). + + Returns ``o_inter`` ``[B, H, S_q, R]``. The caller multiplies by ``W_uv`` + (per head) to obtain the final attention output. + """ + in_dtype = q_nope_abs.dtype + qn = q_nope_abs.float() + qr = q_rope.float() + ck = c_kv.float() + kr = k_rope.float() + + score_nope = torch.einsum("bhmr,bnr->bhmn", qn, ck) + score_rope = torch.einsum("bhmr,bnr->bhmn", qr, kr) + s = (score_nope + score_rope) * softmax_scale + + if is_causal: + s_q, s_kv = s.shape[-2], s.shape[-1] + row = torch.arange(s_q, device=s.device).unsqueeze(-1) + col = torch.arange(s_kv, device=s.device).unsqueeze(0) + mask = col > (row + (s_kv - s_q)) + s = s.masked_fill(mask, float("-inf")) + + p = torch.softmax(s, dim=-1) + o_inter = torch.einsum("bhmn,bnr->bhmr", p, ck) + return o_inter.to(in_dtype) + + +def _launch_mla_decode_fwd(q_nope_abs, q_rope, c_kv, k_rope, softmax_scale, is_causal): + B, H, S_q, R = q_nope_abs.shape + S_kv = c_kv.shape[1] + R_rope = q_rope.shape[3] + + o_inter = torch.empty( + (B, H, S_q, R), device=q_nope_abs.device, dtype=q_nope_abs.dtype + ) + lse = torch.empty((B, H, S_q), device=q_nope_abs.device, dtype=torch.float32) + + block_dmodel_r = max(16, triton.next_power_of_2(R)) + block_dmodel_rr = max(16, triton.next_power_of_2(R_rope)) + + grid = lambda meta: (triton.cdiv(S_q, meta["BLOCK_M"]), B * H) + + _mla_decode_attn_fwd[grid]( + q_nope_abs, + q_rope, + c_kv, + k_rope, + o_inter, + lse, + float(softmax_scale), + B, + H, + S_q, + S_kv, + R, + R_rope, + # Q_nope_abs strides + q_nope_abs.stride(0), q_nope_abs.stride(1), q_nope_abs.stride(2), 1, + # Q_rope strides + q_rope.stride(0), q_rope.stride(1), q_rope.stride(2), 1, + # c_kv strides (no H) + c_kv.stride(0), c_kv.stride(1), 1, + # k_rope strides (no H) + k_rope.stride(0), k_rope.stride(1), 1, + # O_inter strides + o_inter.stride(0), o_inter.stride(1), o_inter.stride(2), 1, + # LSE strides + lse.stride(0), lse.stride(1), 1, + IS_CAUSAL=is_causal, + BLOCK_DMODEL_R=block_dmodel_r, + BLOCK_DMODEL_RR=block_dmodel_rr, + ) + return o_inter, lse + + +def mla_decode_attention( + q_nope_abs: torch.Tensor, + q_rope: torch.Tensor, + c_kv: torch.Tensor, + k_rope: torch.Tensor, + *, + softmax_scale: float, + is_causal: bool = False, +) -> torch.Tensor: + """Triton MLA decode forward over a compressed KV cache (SM80+). + + Implements the FlashMLA-style absorbed up-projection: ``c_kv`` plays the + role of both K (nope side) and V, so per-head K/V are never materialized + inside the kernel. The caller is expected to apply ``W_uv`` to the + returned ``O_inter`` to obtain the final attention output. + + Parameters + ---------- + q_nope_abs : ``[B, H, S_q, R]`` + Q's nope side after absorbing ``W_uk^T`` (so its last-dim is + ``kv_lora_rank``, not the original Q nope dim). + q_rope : ``[B, H, S_q, R_rope]`` + Q's rope side. + c_kv : ``[B, S_kv, R]`` + Compressed KV cache, shared across heads. + k_rope : ``[B, S_kv, R_rope]`` + Decoupled rope keys, shared across heads. + softmax_scale : float + Required. Typically ``1 / sqrt(q_nope_dim + R_rope)`` — i.e. the + original head_dim_qk before absorption. **No default**: the absorbed + Q's last dim (``R``) is *not* the right denominator, so the kernel + will not guess. + is_causal : bool + Right-aligned causal: row ``i`` attends to keys ``j <= i + (S_kv - S_q)``. + + Returns + ------- + torch.Tensor + ``O_inter`` of shape ``[B, H, S_q, R]`` in the input dtype. Apply + ``W_uv`` outside the kernel for the final attention output. + """ + if q_nope_abs.dtype not in (torch.float16, torch.bfloat16): + raise ValueError( + f"mla_decode_attention requires fp16 or bf16 inputs, got {q_nope_abs.dtype}" + ) + if not ( + q_nope_abs.dtype == q_rope.dtype == c_kv.dtype == k_rope.dtype + ): + raise ValueError("q_nope_abs, q_rope, c_kv, k_rope must share dtype") + if not q_nope_abs.is_cuda: + raise ValueError("mla_decode_attention requires CUDA tensors") + + if q_nope_abs.shape[-1] != c_kv.shape[-1]: + raise ValueError( + "q_nope_abs.last_dim and c_kv.last_dim must match (kv_lora_rank);" + f" got {q_nope_abs.shape[-1]} vs {c_kv.shape[-1]}" + ) + if q_rope.shape[-1] != k_rope.shape[-1]: + raise ValueError( + "q_rope.last_dim and k_rope.last_dim must match (rope_dim);" + f" got {q_rope.shape[-1]} vs {k_rope.shape[-1]}" + ) + if q_nope_abs.shape[:3] != q_rope.shape[:3]: + raise ValueError("q_nope_abs and q_rope must share (B, H, S_q)") + if c_kv.shape[:2] != k_rope.shape[:2]: + raise ValueError("c_kv and k_rope must share (B, S_kv)") + if q_nope_abs.shape[0] != c_kv.shape[0]: + raise ValueError("Q-side and KV-cache batch dims must match") + + qn = q_nope_abs.contiguous() if not q_nope_abs.is_contiguous() else q_nope_abs + qr = q_rope.contiguous() if not q_rope.is_contiguous() else q_rope + ck = c_kv.contiguous() if not c_kv.is_contiguous() else c_kv + kr = k_rope.contiguous() if not k_rope.is_contiguous() else k_rope + + o_inter, _lse = _launch_mla_decode_fwd(qn, qr, ck, kr, softmax_scale, is_causal) + return o_inter From 367d037ec002cede5ac3e2dba7b3bec4a17560b2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 Apr 2026 23:23:14 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_mla_triton.py | 31 ++++--- transformer_engine/common/triton/mla.py | 14 +-- .../dot_product_attention/backends.py | 4 +- .../dot_product_attention.py | 15 ++-- transformer_engine/pytorch/triton/mla.py | 85 +++++++++++++------ 5 files changed, 87 insertions(+), 62 deletions(-) diff --git a/tests/pytorch/test_mla_triton.py b/tests/pytorch/test_mla_triton.py index 2aa9c4af8d..830e077365 100644 --- a/tests/pytorch/test_mla_triton.py +++ b/tests/pytorch/test_mla_triton.py @@ -38,10 +38,7 @@ class MLAConfig: @staticmethod def desc(cfg): - return ( - f"b{cfg.b}_h{cfg.h}_sq{cfg.s_q}_skv{cfg.s_kv}" - f"_dqk{cfg.d_qk}_dv{cfg.d_v}" - ) + return f"b{cfg.b}_h{cfg.h}_sq{cfg.s_q}_skv{cfg.s_kv}_dqk{cfg.d_qk}_dv{cfg.d_v}" mla_configs = [ @@ -191,9 +188,7 @@ def test_mla_softmax_scale_default(): cfg = MLAConfig(1, 2, 64, 64, 192, 128) q, k, v = _make_qkv(cfg, torch.bfloat16, "bshd") out_default = mla_attention(q, k, v, qkv_format="bshd") - out_explicit = mla_attention( - q, k, v, softmax_scale=cfg.d_qk**-0.5, qkv_format="bshd" - ) + out_explicit = mla_attention(q, k, v, softmax_scale=cfg.d_qk**-0.5, qkv_format="bshd") torch.testing.assert_close(out_default, out_explicit, atol=0.0, rtol=0.0) @@ -221,10 +216,7 @@ class MLADecodeConfig: @staticmethod def desc(cfg): - return ( - f"b{cfg.b}_h{cfg.h}_sq{cfg.s_q}_skv{cfg.s_kv}" - f"_r{cfg.r}_rrope{cfg.r_rope}" - ) + return f"b{cfg.b}_h{cfg.h}_sq{cfg.s_q}_skv{cfg.s_kv}_r{cfg.r}_rrope{cfg.r_rope}" _decode_configs = [ @@ -256,7 +248,9 @@ def test_mla_decode_forward(cfg, dtype, is_causal): softmax_scale = 1.0 / (cfg.r + cfg.r_rope) ** 0.5 o = mla_decode_attention(qn, qr, ck, kr, softmax_scale=softmax_scale, is_causal=is_causal) - o_ref = mla_decode_attention_ref(qn, qr, ck, kr, softmax_scale=softmax_scale, is_causal=is_causal) + o_ref = mla_decode_attention_ref( + qn, qr, ck, kr, softmax_scale=softmax_scale, is_causal=is_causal + ) assert o.shape == o_ref.shape == (cfg.b, cfg.h, cfg.s_q, cfg.r) assert o.dtype == o_ref.dtype == dtype @@ -299,10 +293,13 @@ def test_dpa_dispatches_to_mla_triton(monkeypatch, dtype, is_causal, qkv_format) monkeypatch.setenv("NVTE_MLA_TRITON", "1") # Force backend cache invalidation so the env-var check re-runs. - from transformer_engine.pytorch.attention.dot_product_attention import dot_product_attention as dpa_mod + from transformer_engine.pytorch.attention.dot_product_attention import ( + dot_product_attention as dpa_mod, + ) + dpa_mod._attention_backends["backend_selection_requires_update"] = True - softmax_scale = cfg.d_qk ** -0.5 + softmax_scale = cfg.d_qk**-0.5 dpa = te.DotProductAttention( num_attention_heads=cfg.h, kv_channels=(cfg.d_qk, cfg.d_v), @@ -323,7 +320,9 @@ def test_dpa_falls_through_when_env_var_unset(monkeypatch): """With NVTE_MLA_TRITON unset, dispatch must NOT route through the MLA Triton backend even for MLA-shaped inputs (existing behavior preserved).""" import transformer_engine.pytorch as te - from transformer_engine.pytorch.attention.dot_product_attention import dot_product_attention as dpa_mod + from transformer_engine.pytorch.attention.dot_product_attention import ( + dot_product_attention as dpa_mod, + ) reset_rng_states() cfg = MLAConfig(2, 4, 256, 256, 192, 128) @@ -336,7 +335,7 @@ def test_dpa_falls_through_when_env_var_unset(monkeypatch): num_attention_heads=cfg.h, kv_channels=(cfg.d_qk, cfg.d_v), attention_dropout=0.0, - softmax_scale=cfg.d_qk ** -0.5, + softmax_scale=cfg.d_qk**-0.5, qkv_format="bshd", attn_mask_type="causal", ).cuda() diff --git a/transformer_engine/common/triton/mla.py b/transformer_engine/common/triton/mla.py index f4ec2a9ed0..3399713643 100644 --- a/transformer_engine/common/triton/mla.py +++ b/transformer_engine/common/triton/mla.py @@ -254,19 +254,9 @@ def _mla_attn_bwd_preprocess( mask_d = offs_d < D_v mask_md = mask_m[:, None] & mask_d[None, :] - o_ptrs = ( - O_ptr - + off_b * sO_b - + off_h * sO_h - + offs_m[:, None] * sO_s - + offs_d[None, :] * sO_d - ) + o_ptrs = O_ptr + off_b * sO_b + off_h * sO_h + offs_m[:, None] * sO_s + offs_d[None, :] * sO_d do_ptrs = ( - DO_ptr - + off_b * sDO_b - + off_h * sDO_h - + offs_m[:, None] * sDO_s - + offs_d[None, :] * sDO_d + DO_ptr + off_b * sDO_b + off_h * sDO_h + offs_m[:, None] * sDO_s + offs_d[None, :] * sDO_d ) o = tl.load(o_ptrs, mask=mask_md, other=0.0).to(tl.float32) do = tl.load(do_ptrs, mask=mask_md, other=0.0).to(tl.float32) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 2bb3ede65c..5a0a72c47a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -2192,9 +2192,7 @@ def __init__(self, softmax_scale: float, attention_dropout: float = 0.0) -> None super().__init__() self.softmax_scale = softmax_scale if attention_dropout != 0.0: - raise ValueError( - "MLATritonAttention does not support attention_dropout > 0." - ) + raise ValueError("MLATritonAttention does not support attention_dropout > 0.") def forward( self, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 8044fc9816..bfdcd4f6a3 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -510,10 +510,14 @@ def __init__( # :meth:`forward` for the dispatch conditions. Always instantiated so # the env var can be flipped at runtime; the underlying triton kernel # is JIT-compiled lazily on first invocation. - self.mla_triton_attention = MLATritonAttention( - softmax_scale, - attention_dropout=attention_dropout, - ) if attention_dropout == 0.0 else None + self.mla_triton_attention = ( + MLATritonAttention( + softmax_scale, + attention_dropout=attention_dropout, + ) + if attention_dropout == 0.0 + else None + ) def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument """ @@ -1520,7 +1524,8 @@ def forward( int(os.getenv("NVTE_MLA_TRITON", "0")) and self.mla_triton_attention is not None and head_dim_qk != head_dim_v - and head_dim_qk in (128, 192, 256) and head_dim_v in (64, 128) + and head_dim_qk in (128, 192, 256) + and head_dim_v in (64, 128) and query_layer.dtype in (torch.bfloat16, torch.float16) and not self.fp8 and not context_parallel diff --git a/transformer_engine/pytorch/triton/mla.py b/transformer_engine/pytorch/triton/mla.py index 20b32dec44..e7d61c2112 100644 --- a/transformer_engine/pytorch/triton/mla.py +++ b/transformer_engine/pytorch/triton/mla.py @@ -208,17 +208,33 @@ def _launch_mla_bwd(q_bhsd, k_bhsd, v_bhsd, o_bhsd, lse, do_bhsd, softmax_scale, common_shape_args = (B, H, S_q, S_kv, D_qk, D_v) common_strides = ( # Q - q_bhsd.stride(0), q_bhsd.stride(1), q_bhsd.stride(2), 1, + q_bhsd.stride(0), + q_bhsd.stride(1), + q_bhsd.stride(2), + 1, # K - k_bhsd.stride(0), k_bhsd.stride(1), k_bhsd.stride(2), 1, + k_bhsd.stride(0), + k_bhsd.stride(1), + k_bhsd.stride(2), + 1, # V - v_bhsd.stride(0), v_bhsd.stride(1), v_bhsd.stride(2), 1, + v_bhsd.stride(0), + v_bhsd.stride(1), + v_bhsd.stride(2), + 1, # dO - do_bhsd.stride(0), do_bhsd.stride(1), do_bhsd.stride(2), 1, + do_bhsd.stride(0), + do_bhsd.stride(1), + do_bhsd.stride(2), + 1, # LSE - lse.stride(0), lse.stride(1), 1, + lse.stride(0), + lse.stride(1), + 1, # Delta - delta.stride(0), delta.stride(1), 1, + delta.stride(0), + delta.stride(1), + 1, ) # 2) dQ @@ -230,7 +246,10 @@ def _launch_mla_bwd(q_bhsd, k_bhsd, v_bhsd, o_bhsd, lse, do_bhsd, softmax_scale, *common_shape_args, *common_strides, # dQ strides - dq.stride(0), dq.stride(1), dq.stride(2), 1, + dq.stride(0), + dq.stride(1), + dq.stride(2), + 1, IS_CAUSAL=is_causal, BLOCK_DMODEL_QK=block_dmodel_qk, BLOCK_DMODEL_V=block_dmodel_v, @@ -246,9 +265,15 @@ def _launch_mla_bwd(q_bhsd, k_bhsd, v_bhsd, o_bhsd, lse, do_bhsd, softmax_scale, *common_shape_args, *common_strides, # dK strides - dk.stride(0), dk.stride(1), dk.stride(2), 1, + dk.stride(0), + dk.stride(1), + dk.stride(2), + 1, # dV strides - dv.stride(0), dv.stride(1), dv.stride(2), 1, + dv.stride(0), + dv.stride(1), + dv.stride(2), + 1, IS_CAUSAL=is_causal, BLOCK_DMODEL_QK=block_dmodel_qk, BLOCK_DMODEL_V=block_dmodel_v, @@ -267,9 +292,7 @@ def forward(ctx, q, k, v, softmax_scale, is_causal, qkv_format): f"qkv_format must be one of {_SUPPORTED_QKV_FORMATS}, got {qkv_format!r}" ) if q.dtype not in (torch.float16, torch.bfloat16): - raise ValueError( - f"mla_attention requires fp16 or bf16 inputs, got {q.dtype}" - ) + raise ValueError(f"mla_attention requires fp16 or bf16 inputs, got {q.dtype}") if not (q.dtype == k.dtype == v.dtype): raise ValueError("q, k, v must share the same dtype") if not q.is_cuda: @@ -281,8 +304,7 @@ def forward(ctx, q, k, v, softmax_scale, is_causal, qkv_format): if q_bhsd.shape[3] != k_bhsd.shape[3]: raise ValueError( - "q.head_dim and k.head_dim must match" - f" (got {q_bhsd.shape[3]} vs {k_bhsd.shape[3]})" + f"q.head_dim and k.head_dim must match (got {q_bhsd.shape[3]} vs {k_bhsd.shape[3]})" ) if q_bhsd.shape[:2] != k_bhsd.shape[:2] or q_bhsd.shape[:2] != v_bhsd.shape[:2]: raise ValueError("q, k, v must share batch and head dimensions") @@ -419,9 +441,7 @@ def _launch_mla_decode_fwd(q_nope_abs, q_rope, c_kv, k_rope, softmax_scale, is_c S_kv = c_kv.shape[1] R_rope = q_rope.shape[3] - o_inter = torch.empty( - (B, H, S_q, R), device=q_nope_abs.device, dtype=q_nope_abs.dtype - ) + o_inter = torch.empty((B, H, S_q, R), device=q_nope_abs.device, dtype=q_nope_abs.dtype) lse = torch.empty((B, H, S_q), device=q_nope_abs.device, dtype=torch.float32) block_dmodel_r = max(16, triton.next_power_of_2(R)) @@ -444,17 +464,32 @@ def _launch_mla_decode_fwd(q_nope_abs, q_rope, c_kv, k_rope, softmax_scale, is_c R, R_rope, # Q_nope_abs strides - q_nope_abs.stride(0), q_nope_abs.stride(1), q_nope_abs.stride(2), 1, + q_nope_abs.stride(0), + q_nope_abs.stride(1), + q_nope_abs.stride(2), + 1, # Q_rope strides - q_rope.stride(0), q_rope.stride(1), q_rope.stride(2), 1, + q_rope.stride(0), + q_rope.stride(1), + q_rope.stride(2), + 1, # c_kv strides (no H) - c_kv.stride(0), c_kv.stride(1), 1, + c_kv.stride(0), + c_kv.stride(1), + 1, # k_rope strides (no H) - k_rope.stride(0), k_rope.stride(1), 1, + k_rope.stride(0), + k_rope.stride(1), + 1, # O_inter strides - o_inter.stride(0), o_inter.stride(1), o_inter.stride(2), 1, + o_inter.stride(0), + o_inter.stride(1), + o_inter.stride(2), + 1, # LSE strides - lse.stride(0), lse.stride(1), 1, + lse.stride(0), + lse.stride(1), + 1, IS_CAUSAL=is_causal, BLOCK_DMODEL_R=block_dmodel_r, BLOCK_DMODEL_RR=block_dmodel_rr, @@ -507,9 +542,7 @@ def mla_decode_attention( raise ValueError( f"mla_decode_attention requires fp16 or bf16 inputs, got {q_nope_abs.dtype}" ) - if not ( - q_nope_abs.dtype == q_rope.dtype == c_kv.dtype == k_rope.dtype - ): + if not (q_nope_abs.dtype == q_rope.dtype == c_kv.dtype == k_rope.dtype): raise ValueError("q_nope_abs, q_rope, c_kv, k_rope must share dtype") if not q_nope_abs.is_cuda: raise ValueError("mla_decode_attention requires CUDA tensors") From e75f6671640f150f11a3cee7704635618129c57d Mon Sep 17 00:00:00 2001 From: Minho Ryu Date: Fri, 1 May 2026 08:44:54 +0900 Subject: [PATCH 3/5] [PyTorch] Address review: fix MLA Triton dispatch corner cases - Map ``window_size=(-1, 0)`` (FlashAttention causal-via-window convention) to ``is_causal=True`` in the MLA Triton dispatch. Previously a caller setting ``attn_mask_type="no_mask"`` together with ``window_size=(-1, 0)`` would land in the kernel with ``is_causal=False`` and silently attend to future tokens. - Use string equality (``== "1"``) instead of ``int(os.getenv(...))`` for ``NVTE_MLA_TRITON``, so non-integer values like ``"true"`` no longer raise ``ValueError`` at dispatch time. - ``mla_decode_attention`` now raises ``NotImplementedError`` when any input has ``requires_grad=True``. The decode kernel is launched directly (no autograd.Function wrapper) and would otherwise drop gradients silently at the kernel boundary; v1 is forward-only. Test: tests/pytorch/test_mla_triton.py adds ``test_mla_decode_rejects_requires_grad``. Signed-off-by: Minho Ryu --- tests/pytorch/test_mla_triton.py | 11 +++++++++++ .../dot_product_attention.py | 15 ++++++++++----- transformer_engine/pytorch/triton/mla.py | 9 +++++++++ 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/test_mla_triton.py b/tests/pytorch/test_mla_triton.py index 830e077365..3d968735a0 100644 --- a/tests/pytorch/test_mla_triton.py +++ b/tests/pytorch/test_mla_triton.py @@ -273,6 +273,17 @@ def test_mla_decode_rejects_dim_mismatch(): mla_decode_attention(qn, qr, bad_ck, kr, softmax_scale=0.1) +def test_mla_decode_rejects_requires_grad(): + """Decode is forward-only in v1; an input with requires_grad=True must + raise NotImplementedError rather than silently dropping gradients at the + kernel boundary.""" + cfg = MLADecodeConfig(1, 2, 1, 64, 64, 16) + qn, qr, ck, kr = _make_decode_inputs(cfg, torch.bfloat16) + qn.requires_grad_(True) + with pytest.raises(NotImplementedError, match="forward-only"): + mla_decode_attention(qn, qr, ck, kr, softmax_scale=0.1) + + # --------------------------------------------------------------------------- # DotProductAttention dispatch (NVTE_MLA_TRITON=1 fast path) # --------------------------------------------------------------------------- diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index bfdcd4f6a3..a382435403 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1521,7 +1521,7 @@ def forward( # the regular FA/Fused/Unfused cascade — defaulting to no behavior # change when the env var is unset. if ( - int(os.getenv("NVTE_MLA_TRITON", "0")) + os.getenv("NVTE_MLA_TRITON", "0") == "1" and self.mla_triton_attention is not None and head_dim_qk != head_dim_v and head_dim_qk in (128, 192, 256) @@ -1539,10 +1539,15 @@ def forward( and torch.cuda.is_available() and torch.cuda.get_device_capability(query_layer.device)[0] >= 8 ): - # TE's "causal" is bottom-right when S_q != S_kv via - # bottom_right_diagonal=True (default). Our kernel right-aligns - # the diagonal by construction, so both map to is_causal=True. - is_causal_dispatch = attn_mask_type in ("causal", "causal_bottom_right") + # Causal masking can be requested via either ``attn_mask_type`` + # (TE's bottom-right alignment when S_q != S_kv matches our + # kernel's right-aligned diagonal) or via FlashAttention-style + # ``window_size=(-1, 0)`` (left=infinite, right=0 → causal). + # Both paths must map to the kernel's ``is_causal=True``. + is_causal_dispatch = ( + attn_mask_type in ("causal", "causal_bottom_right") + or window_size == (-1, 0) + ) self.logger.info("Running with MLATriton backend (NVTE_MLA_TRITON=1)") return self.mla_triton_attention( query_layer, diff --git a/transformer_engine/pytorch/triton/mla.py b/transformer_engine/pytorch/triton/mla.py index e7d61c2112..81c30e2d61 100644 --- a/transformer_engine/pytorch/triton/mla.py +++ b/transformer_engine/pytorch/triton/mla.py @@ -546,6 +546,15 @@ def mla_decode_attention( raise ValueError("q_nope_abs, q_rope, c_kv, k_rope must share dtype") if not q_nope_abs.is_cuda: raise ValueError("mla_decode_attention requires CUDA tensors") + # Forward-only in v1: the kernel is launched directly (no autograd.Function + # wrapper), so a tensor that needs gradients would silently lose them at + # the kernel boundary. Refuse loudly instead. + if any(t.requires_grad for t in (q_nope_abs, q_rope, c_kv, k_rope)): + raise NotImplementedError( + "mla_decode_attention is forward-only in v1; backward through the" + " absorbed-projection decode kernel is not implemented. Detach" + " inputs (or use torch.no_grad()) before calling." + ) if q_nope_abs.shape[-1] != c_kv.shape[-1]: raise ValueError( From 662d34ce388b0eff1c9820801d76886c7e732b2e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 Apr 2026 23:46:13 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../dot_product_attention/dot_product_attention.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index a382435403..d3809b48a9 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1544,10 +1544,10 @@ def forward( # kernel's right-aligned diagonal) or via FlashAttention-style # ``window_size=(-1, 0)`` (left=infinite, right=0 → causal). # Both paths must map to the kernel's ``is_causal=True``. - is_causal_dispatch = ( - attn_mask_type in ("causal", "causal_bottom_right") - or window_size == (-1, 0) - ) + is_causal_dispatch = attn_mask_type in ( + "causal", + "causal_bottom_right", + ) or window_size == (-1, 0) self.logger.info("Running with MLATriton backend (NVTE_MLA_TRITON=1)") return self.mla_triton_attention( query_layer, From 854859a1fc9fc058460f18688b52ccb2052b353a Mon Sep 17 00:00:00 2001 From: Minho Ryu Date: Fri, 1 May 2026 10:06:53 +0900 Subject: [PATCH 5/5] [Common, PyTorch] Drop unused LSE plumbing from decode kernel The decode forward kernel was allocating and writing an fp32 ``LSE`` buffer that was immediately discarded by the wrapper (v1 ships no analytical decode backward, and the prefill backward kernels only consume the prefill kernel's LSE). Remove the buffer, the kernel parameter, and the trailing ``m_i + tl.log(l_i)`` store. A short comment marks the spot for re-introduction if a future change adds decode backward. No semantic change. Decode correctness re-verified across DSv2 dims (R=512, R_rope=64) and smaller smoke configs against the fp32-internal PyTorch reference. Signed-off-by: Minho Ryu --- transformer_engine/common/triton/mla.py | 12 +++--------- transformer_engine/pytorch/triton/mla.py | 11 ++--------- 2 files changed, 5 insertions(+), 18 deletions(-) diff --git a/transformer_engine/common/triton/mla.py b/transformer_engine/common/triton/mla.py index 3399713643..f6a279606b 100644 --- a/transformer_engine/common/triton/mla.py +++ b/transformer_engine/common/triton/mla.py @@ -590,7 +590,6 @@ def _mla_decode_attn_fwd( CKV_ptr, # (B, S_kv, R) compressed KV cache (shared across heads) KR_ptr, # (B, S_kv, R_rope) decoupled rope key (shared across heads) O_ptr, # (B, H, S_q, R) O_inter (caller applies W_uv) - LSE_ptr, # (B, H, S_q) fp32 LSE softmax_scale, B, H, @@ -621,10 +620,6 @@ def _mla_decode_attn_fwd( sO_h, sO_s, sO_r: tl.constexpr, - # LSE strides - sLSE_b, - sLSE_h, - sLSE_s: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, @@ -717,7 +712,6 @@ def _mla_decode_attn_fwd( acc.to(O_ptr.dtype.element_ty), mask=mask_m[:, None] & mask_r[None, :], ) - - lse_base = LSE_ptr + off_b * sLSE_b + off_h * sLSE_h - lse_ptrs = lse_base + offs_m * sLSE_s - tl.store(lse_ptrs, m_i + tl.log(l_i), mask=mask_m) + # LSE is intentionally not saved here — v1 has no analytical decode + # backward. If a future change adds one, re-introduce a fp32 LSE buffer + # and ``m_i + tl.log(l_i)`` store at this point. diff --git a/transformer_engine/pytorch/triton/mla.py b/transformer_engine/pytorch/triton/mla.py index 81c30e2d61..df73f06e6b 100644 --- a/transformer_engine/pytorch/triton/mla.py +++ b/transformer_engine/pytorch/triton/mla.py @@ -442,7 +442,6 @@ def _launch_mla_decode_fwd(q_nope_abs, q_rope, c_kv, k_rope, softmax_scale, is_c R_rope = q_rope.shape[3] o_inter = torch.empty((B, H, S_q, R), device=q_nope_abs.device, dtype=q_nope_abs.dtype) - lse = torch.empty((B, H, S_q), device=q_nope_abs.device, dtype=torch.float32) block_dmodel_r = max(16, triton.next_power_of_2(R)) block_dmodel_rr = max(16, triton.next_power_of_2(R_rope)) @@ -455,7 +454,6 @@ def _launch_mla_decode_fwd(q_nope_abs, q_rope, c_kv, k_rope, softmax_scale, is_c c_kv, k_rope, o_inter, - lse, float(softmax_scale), B, H, @@ -486,15 +484,11 @@ def _launch_mla_decode_fwd(q_nope_abs, q_rope, c_kv, k_rope, softmax_scale, is_c o_inter.stride(1), o_inter.stride(2), 1, - # LSE strides - lse.stride(0), - lse.stride(1), - 1, IS_CAUSAL=is_causal, BLOCK_DMODEL_R=block_dmodel_r, BLOCK_DMODEL_RR=block_dmodel_rr, ) - return o_inter, lse + return o_inter def mla_decode_attention( @@ -578,5 +572,4 @@ def mla_decode_attention( ck = c_kv.contiguous() if not c_kv.is_contiguous() else c_kv kr = k_rope.contiguous() if not k_rope.is_contiguous() else k_rope - o_inter, _lse = _launch_mla_decode_fwd(qn, qr, ck, kr, softmax_scale, is_causal) - return o_inter + return _launch_mla_decode_fwd(qn, qr, ck, kr, softmax_scale, is_causal)