diff --git a/tests/pytorch/test_mla_triton.py b/tests/pytorch/test_mla_triton.py new file mode 100644 index 0000000000..3d968735a0 --- /dev/null +++ b/tests/pytorch/test_mla_triton.py @@ -0,0 +1,356 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tests for the Triton-based MLA kernels (transformer_engine.pytorch.triton.mla).""" + +from dataclasses import dataclass + +import pytest +import torch + +from utils import reset_rng_states +from transformer_engine.pytorch.triton.mla import ( + mla_attention, + mla_attention_ref, + mla_decode_attention, + mla_decode_attention_ref, +) + + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 8, + reason="MLA Triton kernel requires SM80+ (A100 or newer).", +) + +# Disable TF32 to keep the fp32 reference path bit-comparable across runs. +torch.backends.cuda.matmul.allow_tf32 = False + + +@dataclass +class MLAConfig: + b: int + h: int + s_q: int + s_kv: int + d_qk: int + d_v: int + + @staticmethod + def desc(cfg): + return f"b{cfg.b}_h{cfg.h}_sq{cfg.s_q}_skv{cfg.s_kv}_dqk{cfg.d_qk}_dv{cfg.d_v}" + + +mla_configs = [ + # square head dim sanity + MLAConfig(2, 4, 128, 128, 64, 64), + MLAConfig(2, 4, 256, 256, 128, 128), + # DeepSeek-V2 prefill shape + MLAConfig(2, 8, 512, 512, 192, 128), + # bigger seq + MLAConfig(1, 16, 2048, 2048, 192, 128), + # cross-attention-style: S_q != S_kv + MLAConfig(2, 8, 512, 1024, 192, 128), + # non-multiple-of-block seq lengths + MLAConfig(2, 4, 513, 513, 192, 128), + # decode-shaped Q (still using the prefill kernel) + MLAConfig(1, 4, 1, 1024, 192, 128), +] + + +def _tols(dtype): + if dtype == torch.bfloat16: + return dict(atol=2.5e-2, rtol=2.5e-2) + return dict(atol=5e-3, rtol=5e-3) + + +def _make_qkv(cfg, dtype, qkv_format): + """Allocate (q, k, v) in the requested layout. Sized so values stay well within the + kernel's dynamic range (avoids saturating exp in fp32 accumulator).""" + scale = 0.5 + if qkv_format == "bhsd": + q = torch.randn(cfg.b, cfg.h, cfg.s_q, cfg.d_qk, device="cuda", dtype=dtype) * scale + k = torch.randn(cfg.b, cfg.h, cfg.s_kv, cfg.d_qk, device="cuda", dtype=dtype) * scale + v = torch.randn(cfg.b, cfg.h, cfg.s_kv, cfg.d_v, device="cuda", dtype=dtype) * scale + elif qkv_format == "bshd": + q = torch.randn(cfg.b, cfg.s_q, cfg.h, cfg.d_qk, device="cuda", dtype=dtype) * scale + k = torch.randn(cfg.b, cfg.s_kv, cfg.h, cfg.d_qk, device="cuda", dtype=dtype) * scale + v = torch.randn(cfg.b, cfg.s_kv, cfg.h, cfg.d_v, device="cuda", dtype=dtype) * scale + elif qkv_format == "sbhd": + q = torch.randn(cfg.s_q, cfg.b, cfg.h, cfg.d_qk, device="cuda", dtype=dtype) * scale + k = torch.randn(cfg.s_kv, cfg.b, cfg.h, cfg.d_qk, device="cuda", dtype=dtype) * scale + v = torch.randn(cfg.s_kv, cfg.b, cfg.h, cfg.d_v, device="cuda", dtype=dtype) * scale + else: + raise ValueError(qkv_format) + return q, k, v + + +def _ref_in_user_layout(q, k, v, qkv_format, *, softmax_scale=None, is_causal=False): + """Run mla_attention_ref on tensors that are in user layout. + + The reference operates in BHSD; we transpose, run, and transpose back. 
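+
+    Layout maps used (mirroring the wrapper's own canonicalization):
+        bshd [B, S, H, D] -- transpose(1, 2)     --> bhsd
+        sbhd [S, B, H, D] -- permute(1, 2, 0, 3) --> bhsd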
+ """ + if qkv_format == "bshd": + q_b = q.transpose(1, 2).contiguous() + k_b = k.transpose(1, 2).contiguous() + v_b = v.transpose(1, 2).contiguous() + elif qkv_format == "sbhd": + q_b = q.permute(1, 2, 0, 3).contiguous() + k_b = k.permute(1, 2, 0, 3).contiguous() + v_b = v.permute(1, 2, 0, 3).contiguous() + else: + q_b, k_b, v_b = q, k, v + out_bhsd = mla_attention_ref(q_b, k_b, v_b, softmax_scale=softmax_scale, is_causal=is_causal) + if qkv_format == "bshd": + return out_bhsd.transpose(1, 2).contiguous() + if qkv_format == "sbhd": + return out_bhsd.permute(2, 0, 1, 3).contiguous() + return out_bhsd + + +@pytest.mark.parametrize("cfg", mla_configs, ids=MLAConfig.desc) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"]) +@pytest.mark.parametrize("is_causal", [False, True], ids=["nocausal", "causal"]) +@pytest.mark.parametrize("qkv_format", ["bshd", "bhsd", "sbhd"]) +def test_mla_forward(cfg, dtype, is_causal, qkv_format): + reset_rng_states() + q, k, v = _make_qkv(cfg, dtype, qkv_format) + + out_triton = mla_attention(q, k, v, is_causal=is_causal, qkv_format=qkv_format) + out_ref = _ref_in_user_layout(q, k, v, qkv_format, is_causal=is_causal) + + assert out_triton.shape == out_ref.shape + assert out_triton.dtype == out_ref.dtype + torch.testing.assert_close(out_triton, out_ref, **_tols(dtype)) + + +_backward_configs = [ + MLAConfig(2, 4, 128, 128, 64, 64), + MLAConfig(2, 4, 256, 256, 128, 128), + MLAConfig(2, 8, 256, 256, 192, 128), # DeepSeek-V2 dims, smaller seq + MLAConfig(2, 8, 512, 512, 192, 128), # DeepSeek-V2 prefill + MLAConfig(2, 4, 256, 512, 192, 128), # cross-attn shape + MLAConfig(2, 4, 257, 257, 192, 128), # non-multiple-of-block seq +] + + +@pytest.mark.parametrize("cfg", _backward_configs, ids=MLAConfig.desc) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"]) +@pytest.mark.parametrize("is_causal", [False, True], ids=["nocausal", "causal"]) +def test_mla_backward_matches_reference(cfg, dtype, is_causal): + """Triton-computed dQ/dK/dV must match the pure-PyTorch reference within bf16/fp16 tolerances.""" + reset_rng_states() + q_base, k_base, v_base = _make_qkv(cfg, dtype, qkv_format="bshd") + q = q_base.detach().clone().requires_grad_(True) + k = k_base.detach().clone().requires_grad_(True) + v = v_base.detach().clone().requires_grad_(True) + q_ref = q_base.detach().clone().requires_grad_(True) + k_ref = k_base.detach().clone().requires_grad_(True) + v_ref = v_base.detach().clone().requires_grad_(True) + + out_triton = mla_attention(q, k, v, is_causal=is_causal, qkv_format="bshd") + out_ref = _ref_in_user_layout(q_ref, k_ref, v_ref, "bshd", is_causal=is_causal) + + # Smaller-magnitude grad_o so dS = P*(dP - Delta) stays well within the + # accumulator's representable range for the non-square head-dim cases. 
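+    # Both dP = dO @ V^T and Delta = rowsum(O * dO) are linear in dO, so the
+    # 0.1x factor shrinks dS (and hence dQ/dK/dV) proportionally rather than
+    # altering the computation being compared.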
+ grad_o = torch.randn_like(out_triton) * 0.1 + out_triton.backward(grad_o) + out_ref.backward(grad_o) + + tols = _tols(dtype) + torch.testing.assert_close(q.grad, q_ref.grad, **tols) + torch.testing.assert_close(k.grad, k_ref.grad, **tols) + torch.testing.assert_close(v.grad, v_ref.grad, **tols) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"]) +@pytest.mark.parametrize("is_causal", [False, True], ids=["nocausal", "causal"]) +def test_mla_layout_equivalence(dtype, is_causal): + """Same logical inputs in BSHD and BHSD must produce equal outputs (modulo layout).""" + reset_rng_states() + cfg = MLAConfig(2, 4, 256, 256, 192, 128) + q_b, k_b, v_b = _make_qkv(cfg, dtype, qkv_format="bhsd") + q_s = q_b.transpose(1, 2).contiguous() + k_s = k_b.transpose(1, 2).contiguous() + v_s = v_b.transpose(1, 2).contiguous() + + out_bhsd = mla_attention(q_b, k_b, v_b, is_causal=is_causal, qkv_format="bhsd") + out_bshd = mla_attention(q_s, k_s, v_s, is_causal=is_causal, qkv_format="bshd") + + # Triton kernel runs on the canonicalized BHSD path in both cases, so outputs + # must match exactly after layout permutation. + torch.testing.assert_close(out_bhsd, out_bshd.transpose(1, 2).contiguous(), atol=0.0, rtol=0.0) + + +def test_mla_softmax_scale_default(): + """Verify the default softmax_scale is 1/sqrt(head_dim_qk) (Q-side, not V-side).""" + reset_rng_states() + cfg = MLAConfig(1, 2, 64, 64, 192, 128) + q, k, v = _make_qkv(cfg, torch.bfloat16, "bshd") + out_default = mla_attention(q, k, v, qkv_format="bshd") + out_explicit = mla_attention(q, k, v, softmax_scale=cfg.d_qk**-0.5, qkv_format="bshd") + torch.testing.assert_close(out_default, out_explicit, atol=0.0, rtol=0.0) + + +def test_mla_rejects_fp32_input(): + """The Triton kernel is fp16/bf16 only — fp32 inputs should raise.""" + cfg = MLAConfig(1, 2, 64, 64, 64, 64) + q, k, v = _make_qkv(cfg, torch.float32, "bshd") + with pytest.raises(ValueError, match="fp16 or bf16"): + mla_attention(q, k, v, qkv_format="bshd") + + +# --------------------------------------------------------------------------- +# Decode (absorbed up-projection) tests +# --------------------------------------------------------------------------- + + +@dataclass +class MLADecodeConfig: + b: int + h: int + s_q: int + s_kv: int + r: int # kv_lora_rank + r_rope: int + + @staticmethod + def desc(cfg): + return f"b{cfg.b}_h{cfg.h}_sq{cfg.s_q}_skv{cfg.s_kv}_r{cfg.r}_rrope{cfg.r_rope}" + + +_decode_configs = [ + MLADecodeConfig(1, 4, 1, 128, 64, 16), # smoke + MLADecodeConfig(1, 4, 1, 512, 64, 16), # bigger Skv + MLADecodeConfig(1, 4, 4, 512, 64, 16), # multi-token / speculative decode + MLADecodeConfig(2, 4, 128, 128, 128, 32), # prefill via decode kernel + # DeepSeek-V2 dims (kv_lora_rank=512, rope_dim=64). 
+ MLADecodeConfig(1, 8, 1, 1024, 512, 64), + MLADecodeConfig(1, 16, 1, 2048, 512, 64), + MLADecodeConfig(1, 4, 1, 257, 512, 64), # non-multiple-of-block S_kv +] + + +def _make_decode_inputs(cfg, dtype, scale=0.5): + qn = torch.randn(cfg.b, cfg.h, cfg.s_q, cfg.r, device="cuda", dtype=dtype) * scale + qr = torch.randn(cfg.b, cfg.h, cfg.s_q, cfg.r_rope, device="cuda", dtype=dtype) * scale + ck = torch.randn(cfg.b, cfg.s_kv, cfg.r, device="cuda", dtype=dtype) * scale + kr = torch.randn(cfg.b, cfg.s_kv, cfg.r_rope, device="cuda", dtype=dtype) * scale + return qn, qr, ck, kr + + +@pytest.mark.parametrize("cfg", _decode_configs, ids=MLADecodeConfig.desc) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"]) +@pytest.mark.parametrize("is_causal", [False, True], ids=["nocausal", "causal"]) +def test_mla_decode_forward(cfg, dtype, is_causal): + reset_rng_states() + qn, qr, ck, kr = _make_decode_inputs(cfg, dtype) + softmax_scale = 1.0 / (cfg.r + cfg.r_rope) ** 0.5 + + o = mla_decode_attention(qn, qr, ck, kr, softmax_scale=softmax_scale, is_causal=is_causal) + o_ref = mla_decode_attention_ref( + qn, qr, ck, kr, softmax_scale=softmax_scale, is_causal=is_causal + ) + + assert o.shape == o_ref.shape == (cfg.b, cfg.h, cfg.s_q, cfg.r) + assert o.dtype == o_ref.dtype == dtype + torch.testing.assert_close(o, o_ref, **_tols(dtype)) + + +def test_mla_decode_rejects_fp32_input(): + cfg = MLADecodeConfig(1, 2, 1, 64, 64, 16) + qn, qr, ck, kr = _make_decode_inputs(cfg, torch.float32) + with pytest.raises(ValueError, match="fp16 or bf16"): + mla_decode_attention(qn, qr, ck, kr, softmax_scale=0.1) + + +def test_mla_decode_rejects_dim_mismatch(): + cfg = MLADecodeConfig(1, 2, 1, 64, 64, 16) + qn, qr, ck, kr = _make_decode_inputs(cfg, torch.bfloat16) + # Mismatched kv_lora_rank between Q-side and cache. + bad_ck = torch.randn(1, 64, 32, device="cuda", dtype=torch.bfloat16) + with pytest.raises(ValueError, match="kv_lora_rank"): + mla_decode_attention(qn, qr, bad_ck, kr, softmax_scale=0.1) + + +def test_mla_decode_rejects_requires_grad(): + """Decode is forward-only in v1; an input with requires_grad=True must + raise NotImplementedError rather than silently dropping gradients at the + kernel boundary.""" + cfg = MLADecodeConfig(1, 2, 1, 64, 64, 16) + qn, qr, ck, kr = _make_decode_inputs(cfg, torch.bfloat16) + qn.requires_grad_(True) + with pytest.raises(NotImplementedError, match="forward-only"): + mla_decode_attention(qn, qr, ck, kr, softmax_scale=0.1) + + +# --------------------------------------------------------------------------- +# DotProductAttention dispatch (NVTE_MLA_TRITON=1 fast path) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"]) +@pytest.mark.parametrize("is_causal", [False, True], ids=["nocausal", "causal"]) +@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd"]) +def test_dpa_dispatches_to_mla_triton(monkeypatch, dtype, is_causal, qkv_format): + """With NVTE_MLA_TRITON=1, an MLA-shaped DotProductAttention call must + produce the same output as a direct ``mla_attention`` call.""" + import transformer_engine.pytorch as te + + reset_rng_states() + cfg = MLAConfig(2, 4, 256, 256, 192, 128) + q, k, v = _make_qkv(cfg, dtype, qkv_format=qkv_format) + + monkeypatch.setenv("NVTE_MLA_TRITON", "1") + + # Force backend cache invalidation so the env-var check re-runs. 
+ from transformer_engine.pytorch.attention.dot_product_attention import ( + dot_product_attention as dpa_mod, + ) + + dpa_mod._attention_backends["backend_selection_requires_update"] = True + + softmax_scale = cfg.d_qk**-0.5 + dpa = te.DotProductAttention( + num_attention_heads=cfg.h, + kv_channels=(cfg.d_qk, cfg.d_v), + attention_dropout=0.0, + softmax_scale=softmax_scale, + qkv_format=qkv_format, + attn_mask_type="causal" if is_causal else "no_mask", + ).cuda() + + out_dpa = dpa(q, k, v) + out_direct = mla_attention( + q, k, v, softmax_scale=softmax_scale, is_causal=is_causal, qkv_format=qkv_format + ) + torch.testing.assert_close(out_dpa, out_direct, atol=0.0, rtol=0.0) + + +def test_dpa_falls_through_when_env_var_unset(monkeypatch): + """With NVTE_MLA_TRITON unset, dispatch must NOT route through the MLA + Triton backend even for MLA-shaped inputs (existing behavior preserved).""" + import transformer_engine.pytorch as te + from transformer_engine.pytorch.attention.dot_product_attention import ( + dot_product_attention as dpa_mod, + ) + + reset_rng_states() + cfg = MLAConfig(2, 4, 256, 256, 192, 128) + q, k, v = _make_qkv(cfg, torch.bfloat16, qkv_format="bshd") + + monkeypatch.delenv("NVTE_MLA_TRITON", raising=False) + dpa_mod._attention_backends["backend_selection_requires_update"] = True + + dpa = te.DotProductAttention( + num_attention_heads=cfg.h, + kv_channels=(cfg.d_qk, cfg.d_v), + attention_dropout=0.0, + softmax_scale=cfg.d_qk**-0.5, + qkv_format="bshd", + attn_mask_type="causal", + ).cuda() + + # Just exercise it — any cuDNN/FA backend may be selected. We only assert + # no exception (i.e. the early-out wasn't accidentally triggered). + _ = dpa(q, k, v) diff --git a/transformer_engine/common/triton/mla.py b/transformer_engine/common/triton/mla.py new file mode 100644 index 0000000000..f6a279606b --- /dev/null +++ b/transformer_engine/common/triton/mla.py @@ -0,0 +1,717 @@ +# pylint: disable=missing-function-docstring + +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Triton kernels for MLA (Multi-head Latent Attention). + +Two operating modes: + +1. **Prefill / training** (``_mla_attn_fwd`` + backward kernels): + FlashAttention-2 style with non-square head dimensions + (``head_dim_qk != head_dim_v`` as used by DeepSeek-V2/V3). Inputs are full + Q, K, V tensors. Backward uses the canonical FA-2 three-pass structure + (preprocess for ``Delta = rowsum(O * dO)``, then ``dQ`` and ``dK``/``dV`` + kernels). No atomics — each program owns a distinct output slice. + +2. **Absorbed-projection decode** (``_mla_decode_attn_fwd``): + Operates on a compressed KV cache (latent ``c_kv`` of dim ``kv_lora_rank`` + plus a decoupled rope key ``k_rope``). Q's nope side is pre-absorbed via + ``W_uk^T`` so the kernel never materializes per-head K or V; ``c_kv`` + serves as both the key (nope side) and the value. Output is the + pre-``W_uv`` intermediate ``O_inter`` of shape ``[B, H, S_q, kv_lora_rank]``; + the caller applies ``W_uv`` to produce the final attention output. + +Tuned for SM80 (A100). Compiles on SM89/SM90 too but is not specialized for them. 
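+
+A minimal prefill-path sketch through the PyTorch wrapper (shapes are
+illustrative only, mirroring the DeepSeek-V2-style test configs)::
+
+    import torch
+    from transformer_engine.pytorch.triton.mla import mla_attention
+
+    # bshd layout, head_dim_qk=192 != head_dim_v=128, bf16 on a CUDA device.
+    q = torch.randn(2, 512, 8, 192, device="cuda", dtype=torch.bfloat16)
+    k = torch.randn(2, 512, 8, 192, device="cuda", dtype=torch.bfloat16)
+    v = torch.randn(2, 512, 8, 128, device="cuda", dtype=torch.bfloat16)
+    out = mla_attention(q, k, v, is_causal=True, qkv_format="bshd")  # [2, 512, 8, 128]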
+""" + +import itertools +import os + +import triton +import triton.language as tl + + +def _mla_fwd_configs(): + block_m = [64, 128] + block_n = [32, 64] + num_warps = [4, 8] + num_stages = [2, 3] + + configs = [] + for m, n, w, s in itertools.product(block_m, block_n, num_warps, num_stages): + configs.append( + triton.Config( + {"BLOCK_M": m, "BLOCK_N": n}, + num_warps=w, + num_stages=s, + ) + ) + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs + + +@triton.autotune( + configs=_mla_fwd_configs(), + key=["S_q", "S_kv", "D_qk", "D_v", "IS_CAUSAL"], +) +@triton.jit +def _mla_attn_fwd( + Q_ptr, # (B, H, S_q, D_qk) + K_ptr, # (B, H, S_kv, D_qk) + V_ptr, # (B, H, S_kv, D_v) + O_ptr, # (B, H, S_q, D_v) + LSE_ptr, # (B, H, S_q) fp32 + softmax_scale, + B, + H, + S_q, + S_kv, + D_qk, + D_v, + sQ_b, + sQ_h, + sQ_s, + sQ_d: tl.constexpr, + sK_b, + sK_h, + sK_s, + sK_d: tl.constexpr, + sV_b, + sV_h, + sV_s, + sV_d: tl.constexpr, + sO_b, + sO_h, + sO_s, + sO_d: tl.constexpr, + sL_b, + sL_h, + sL_s: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL_QK: tl.constexpr, + BLOCK_DMODEL_V: tl.constexpr, +): + """One program -> one BLOCK_M tile of Q rows for one (batch, head). + + Layout: BHSD. Right-aligned causal: row i attends to keys j s.t. + ``j <= i + (S_kv - S_q)``. + """ + pid_m = tl.program_id(0) + pid_bh = tl.program_id(1) + tl.assume(pid_m >= 0) + tl.assume(pid_bh >= 0) + + off_b = pid_bh // H + off_h = pid_bh % H + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n_init = tl.arange(0, BLOCK_N) + offs_d_qk = tl.arange(0, BLOCK_DMODEL_QK) + offs_d_v = tl.arange(0, BLOCK_DMODEL_V) + + mask_m = offs_m < S_q + mask_d_qk = offs_d_qk < D_qk + mask_d_v = offs_d_v < D_v + + # Load Q tile once (resident across the inner K/V loop). + q_base = Q_ptr + off_b * sQ_b + off_h * sQ_h + q_ptrs = q_base + offs_m[:, None] * sQ_s + offs_d_qk[None, :] * sQ_d + q = tl.load(q_ptrs, mask=mask_m[:, None] & mask_d_qk[None, :], other=0.0) + + k_base = K_ptr + off_b * sK_b + off_h * sK_h + v_base = V_ptr + off_b * sV_b + off_h * sV_h + + # ``m_i`` is initialized to a finite "very negative" sentinel rather than -inf + # so that exp(m_i - m_ij) is well-defined even when an entire K block is masked + # (every qk == -inf, e.g. the rare S_q > S_kv + causal case). -1e6 is large + # enough that exp(-1e6 - any_realistic_qk) underflows to 0 in fp32. + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - 1.0e6 + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_V], dtype=tl.float32) + + # Loop bound: standard FA-2 early exit for causal. No clamp to 0 needed — + # ``range`` with a non-positive upper bound is empty. + if IS_CAUSAL: + hi = tl.minimum((pid_m + 1) * BLOCK_M + (S_kv - S_q), S_kv) + else: + hi = S_kv + + for start_n in range(0, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + offs_n = start_n + offs_n_init + mask_n = offs_n < S_kv + + # Load K tile [BLOCK_N, BLOCK_DMODEL_QK]. + k_ptrs = k_base + offs_n[:, None] * sK_s + offs_d_qk[None, :] * sK_d + k = tl.load(k_ptrs, mask=mask_n[:, None] & mask_d_qk[None, :], other=0.0) + + # qk = Q @ K^T -> [BLOCK_M, BLOCK_N] (fp32 accum) + qk = tl.dot(q, tl.trans(k)) + qk = qk * softmax_scale + + # Mask invalid keys (K-tail past S_kv) to -inf BEFORE softmax max/exp. 
+ qk = tl.where(mask_n[None, :], qk, float("-inf")) + if IS_CAUSAL: + causal_mask = offs_n[None, :] <= (offs_m[:, None] + (S_kv - S_q)) + qk = tl.where(causal_mask, qk, float("-inf")) + + # Online softmax. ``m_i`` is initialized to a finite sentinel so that + # exp(m_i - m_ij) is well-defined even when this entire block is masked + # (all qk == -inf). + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.exp(m_i - m_ij) + p = tl.exp(qk - m_ij[:, None]) + l_i = l_i * alpha + tl.sum(p, 1) + acc = acc * alpha[:, None] + + # Load V tile [BLOCK_N, BLOCK_DMODEL_V] and accumulate P @ V. + v_ptrs = v_base + offs_n[:, None] * sV_s + offs_d_v[None, :] * sV_d + v = tl.load(v_ptrs, mask=mask_n[:, None] & mask_d_v[None, :], other=0.0) + acc = tl.dot(p.to(v.dtype), v, acc) + + m_i = m_ij + + # Epilogue: normalize and store O in input dtype. + acc = acc / l_i[:, None] + o_base = O_ptr + off_b * sO_b + off_h * sO_h + o_ptrs = o_base + offs_m[:, None] * sO_s + offs_d_v[None, :] * sO_d + tl.store( + o_ptrs, + acc.to(O_ptr.dtype.element_ty), + mask=mask_m[:, None] & mask_d_v[None, :], + ) + + # Store fp32 LSE = m_i + log(l_i) (used by the analytical backward). + lse_base = LSE_ptr + off_b * sL_b + off_h * sL_h + lse_ptrs = lse_base + offs_m * sL_s + tl.store(lse_ptrs, m_i + tl.log(l_i), mask=mask_m) + + +# --------------------------------------------------------------------------- +# Backward kernels (FA-2 style, three passes, no atomics) +# --------------------------------------------------------------------------- + + +def _mla_bwd_configs(): + block_m = [64, 128] + block_n = [32, 64] + num_warps = [4, 8] + num_stages = [2, 3] + configs = [] + for m, n, w, s in itertools.product(block_m, block_n, num_warps, num_stages): + configs.append( + triton.Config( + {"BLOCK_M": m, "BLOCK_N": n}, + num_warps=w, + num_stages=s, + ) + ) + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs + + +@triton.jit +def _mla_attn_bwd_preprocess( + O_ptr, # (B, H, S_q, D_v) + DO_ptr, # (B, H, S_q, D_v) + Delta_ptr, # (B, H, S_q) fp32 output + B, + H, + S_q, + D_v, + sO_b, + sO_h, + sO_s, + sO_d: tl.constexpr, + sDO_b, + sDO_h, + sDO_s, + sDO_d: tl.constexpr, + sD_b, + sD_h, + sD_s: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL_V: tl.constexpr, +): + """Compute ``Delta[b,h,m] = sum_d O[b,h,m,d] * dO[b,h,m,d]`` in fp32.""" + pid_m = tl.program_id(0) + pid_bh = tl.program_id(1) + tl.assume(pid_m >= 0) + tl.assume(pid_bh >= 0) + + off_b = pid_bh // H + off_h = pid_bh % H + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_DMODEL_V) + mask_m = offs_m < S_q + mask_d = offs_d < D_v + mask_md = mask_m[:, None] & mask_d[None, :] + + o_ptrs = O_ptr + off_b * sO_b + off_h * sO_h + offs_m[:, None] * sO_s + offs_d[None, :] * sO_d + do_ptrs = ( + DO_ptr + off_b * sDO_b + off_h * sDO_h + offs_m[:, None] * sDO_s + offs_d[None, :] * sDO_d + ) + o = tl.load(o_ptrs, mask=mask_md, other=0.0).to(tl.float32) + do = tl.load(do_ptrs, mask=mask_md, other=0.0).to(tl.float32) + + delta = tl.sum(o * do, axis=1) + delta_ptrs = Delta_ptr + off_b * sD_b + off_h * sD_h + offs_m * sD_s + tl.store(delta_ptrs, delta, mask=mask_m) + + +@triton.autotune( + configs=_mla_bwd_configs(), + key=["S_q", "S_kv", "D_qk", "D_v", "IS_CAUSAL"], +) +@triton.jit +def _mla_attn_bwd_dq( + Q_ptr, # (B, H, S_q, D_qk) + K_ptr, # (B, H, S_kv, D_qk) + V_ptr, # (B, H, S_kv, D_v) + DO_ptr, # (B, H, S_q, D_v) + LSE_ptr, # (B, H, S_q) fp32 + Delta_ptr, # (B, H, S_q) fp32 + DQ_ptr, # (B, H, S_q, 
D_qk) output + softmax_scale, + B, + H, + S_q, + S_kv, + D_qk, + D_v, + sQ_b, + sQ_h, + sQ_s, + sQ_d: tl.constexpr, + sK_b, + sK_h, + sK_s, + sK_d: tl.constexpr, + sV_b, + sV_h, + sV_s, + sV_d: tl.constexpr, + sDO_b, + sDO_h, + sDO_s, + sDO_d: tl.constexpr, + sLSE_b, + sLSE_h, + sLSE_s: tl.constexpr, + sD_b, + sD_h, + sD_s: tl.constexpr, + sDQ_b, + sDQ_h, + sDQ_s, + sDQ_d: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL_QK: tl.constexpr, + BLOCK_DMODEL_V: tl.constexpr, +): + """Compute ``dQ`` for one BLOCK_M row tile of one (batch, head).""" + pid_m = tl.program_id(0) + pid_bh = tl.program_id(1) + tl.assume(pid_m >= 0) + tl.assume(pid_bh >= 0) + + off_b = pid_bh // H + off_h = pid_bh % H + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n_init = tl.arange(0, BLOCK_N) + offs_d_qk = tl.arange(0, BLOCK_DMODEL_QK) + offs_d_v = tl.arange(0, BLOCK_DMODEL_V) + + mask_m = offs_m < S_q + mask_d_qk = offs_d_qk < D_qk + mask_d_v = offs_d_v < D_v + + # Load Q, dO, LSE, Delta — resident across the K/V loop. + q_base = Q_ptr + off_b * sQ_b + off_h * sQ_h + q_ptrs = q_base + offs_m[:, None] * sQ_s + offs_d_qk[None, :] * sQ_d + q = tl.load(q_ptrs, mask=mask_m[:, None] & mask_d_qk[None, :], other=0.0) + + do_base = DO_ptr + off_b * sDO_b + off_h * sDO_h + do_ptrs = do_base + offs_m[:, None] * sDO_s + offs_d_v[None, :] * sDO_d + do = tl.load(do_ptrs, mask=mask_m[:, None] & mask_d_v[None, :], other=0.0) + + # ``other=+inf`` so that exp(qk - lse) underflows to 0 for invalid Q rows + # and they contribute nothing to dQ. + lse_base = LSE_ptr + off_b * sLSE_b + off_h * sLSE_h + lse = tl.load(lse_base + offs_m * sLSE_s, mask=mask_m, other=float("inf")) + + delta_base = Delta_ptr + off_b * sD_b + off_h * sD_h + delta = tl.load(delta_base + offs_m * sD_s, mask=mask_m, other=0.0) + + dq = tl.zeros([BLOCK_M, BLOCK_DMODEL_QK], dtype=tl.float32) + + if IS_CAUSAL: + n_hi = tl.minimum((pid_m + 1) * BLOCK_M + (S_kv - S_q), S_kv) + else: + n_hi = S_kv + + k_base = K_ptr + off_b * sK_b + off_h * sK_h + v_base = V_ptr + off_b * sV_b + off_h * sV_h + + for start_n in range(0, n_hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + offs_n = start_n + offs_n_init + mask_n = offs_n < S_kv + + k_ptrs = k_base + offs_n[:, None] * sK_s + offs_d_qk[None, :] * sK_d + k = tl.load(k_ptrs, mask=mask_n[:, None] & mask_d_qk[None, :], other=0.0) + v_ptrs = v_base + offs_n[:, None] * sV_s + offs_d_v[None, :] * sV_d + v = tl.load(v_ptrs, mask=mask_n[:, None] & mask_d_v[None, :], other=0.0) + + # Recompute qk and apply same masks as the forward. 
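+        # P itself is rebuilt below as exp(qk - lse) from the fp32 LSE saved by
+        # the forward, so the [S_q, S_kv] probability matrix is never stored.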
+ qk = tl.dot(q, tl.trans(k)) + qk = qk * softmax_scale + qk = tl.where(mask_n[None, :], qk, float("-inf")) + if IS_CAUSAL: + causal_mask = offs_n[None, :] <= (offs_m[:, None] + (S_kv - S_q)) + qk = tl.where(causal_mask, qk, float("-inf")) + + p = tl.exp(qk - lse[:, None]) + # dP = dO @ V^T -> [BLOCK_M, BLOCK_N] + dp = tl.dot(do, tl.trans(v)) + # dS = P * (dP - Delta) * scale + ds = (p * (dp - delta[:, None])) * softmax_scale + # dQ += dS @ K + dq += tl.dot(ds.to(k.dtype), k) + + dq_base = DQ_ptr + off_b * sDQ_b + off_h * sDQ_h + dq_ptrs = dq_base + offs_m[:, None] * sDQ_s + offs_d_qk[None, :] * sDQ_d + tl.store( + dq_ptrs, + dq.to(DQ_ptr.dtype.element_ty), + mask=mask_m[:, None] & mask_d_qk[None, :], + ) + + +@triton.autotune( + configs=_mla_bwd_configs(), + key=["S_q", "S_kv", "D_qk", "D_v", "IS_CAUSAL"], +) +@triton.jit +def _mla_attn_bwd_dkv( + Q_ptr, + K_ptr, + V_ptr, + DO_ptr, + LSE_ptr, + Delta_ptr, + DK_ptr, # (B, H, S_kv, D_qk) output + DV_ptr, # (B, H, S_kv, D_v) output + softmax_scale, + B, + H, + S_q, + S_kv, + D_qk, + D_v, + sQ_b, + sQ_h, + sQ_s, + sQ_d: tl.constexpr, + sK_b, + sK_h, + sK_s, + sK_d: tl.constexpr, + sV_b, + sV_h, + sV_s, + sV_d: tl.constexpr, + sDO_b, + sDO_h, + sDO_s, + sDO_d: tl.constexpr, + sLSE_b, + sLSE_h, + sLSE_s: tl.constexpr, + sD_b, + sD_h, + sD_s: tl.constexpr, + sDK_b, + sDK_h, + sDK_s, + sDK_d: tl.constexpr, + sDV_b, + sDV_h, + sDV_s, + sDV_d: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL_QK: tl.constexpr, + BLOCK_DMODEL_V: tl.constexpr, +): + """Compute ``dK`` and ``dV`` for one BLOCK_N tile of one (batch, head).""" + pid_n = tl.program_id(0) + pid_bh = tl.program_id(1) + tl.assume(pid_n >= 0) + tl.assume(pid_bh >= 0) + + off_b = pid_bh // H + off_h = pid_bh % H + + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m_init = tl.arange(0, BLOCK_M) + offs_d_qk = tl.arange(0, BLOCK_DMODEL_QK) + offs_d_v = tl.arange(0, BLOCK_DMODEL_V) + + mask_n = offs_n < S_kv + mask_d_qk = offs_d_qk < D_qk + mask_d_v = offs_d_v < D_v + + # Load K, V tiles — resident across the Q loop. + k_base = K_ptr + off_b * sK_b + off_h * sK_h + k_ptrs = k_base + offs_n[:, None] * sK_s + offs_d_qk[None, :] * sK_d + k = tl.load(k_ptrs, mask=mask_n[:, None] & mask_d_qk[None, :], other=0.0) + + v_base = V_ptr + off_b * sV_b + off_h * sV_h + v_ptrs = v_base + offs_n[:, None] * sV_s + offs_d_v[None, :] * sV_d + v = tl.load(v_ptrs, mask=mask_n[:, None] & mask_d_v[None, :], other=0.0) + + dk = tl.zeros([BLOCK_N, BLOCK_DMODEL_QK], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_DMODEL_V], dtype=tl.float32) + + # Causal: only Q rows i with j_max(i) >= n_start contribute. Round the + # lower bound *down* to a multiple of BLOCK_M so loads stay aligned; the + # extra rows iterated are masked out per-element. + if IS_CAUSAL: + m_lo = pid_n * BLOCK_N - (S_kv - S_q) + m_lo = tl.maximum(m_lo, 0) + m_lo = (m_lo // BLOCK_M) * BLOCK_M + else: + m_lo = 0 + + for start_m in range(m_lo, S_q, BLOCK_M): + start_m = tl.multiple_of(start_m, BLOCK_M) + offs_m = start_m + offs_m_init + mask_m = offs_m < S_q + + q_base = Q_ptr + off_b * sQ_b + off_h * sQ_h + q_ptrs = q_base + offs_m[:, None] * sQ_s + offs_d_qk[None, :] * sQ_d + q = tl.load(q_ptrs, mask=mask_m[:, None] & mask_d_qk[None, :], other=0.0) + + # Recompute qk with the forward's mask convention. 
+ qk = tl.dot(q, tl.trans(k)) + qk = qk * softmax_scale + qk = tl.where(mask_n[None, :], qk, float("-inf")) + if IS_CAUSAL: + causal_mask = offs_n[None, :] <= (offs_m[:, None] + (S_kv - S_q)) + qk = tl.where(causal_mask, qk, float("-inf")) + + # Invalid Q rows: load LSE with ``+inf`` so P underflows to 0 and these + # rows contribute nothing to dK/dV. + lse_base = LSE_ptr + off_b * sLSE_b + off_h * sLSE_h + lse = tl.load(lse_base + offs_m * sLSE_s, mask=mask_m, other=float("inf")) + delta_base = Delta_ptr + off_b * sD_b + off_h * sD_h + delta = tl.load(delta_base + offs_m * sD_s, mask=mask_m, other=0.0) + + do_base = DO_ptr + off_b * sDO_b + off_h * sDO_h + do_ptrs = do_base + offs_m[:, None] * sDO_s + offs_d_v[None, :] * sDO_d + do = tl.load(do_ptrs, mask=mask_m[:, None] & mask_d_v[None, :], other=0.0) + + p = tl.exp(qk - lse[:, None]) # [BLOCK_M, BLOCK_N] + + # dV += P^T @ dO + dv += tl.dot(tl.trans(p).to(do.dtype), do) + + # dP = dO @ V^T -> [BLOCK_M, BLOCK_N] + dp = tl.dot(do, tl.trans(v)) + ds = (p * (dp - delta[:, None])) * softmax_scale + # dK += dS^T @ Q + dk += tl.dot(tl.trans(ds).to(q.dtype), q) + + dk_base = DK_ptr + off_b * sDK_b + off_h * sDK_h + dk_ptrs = dk_base + offs_n[:, None] * sDK_s + offs_d_qk[None, :] * sDK_d + tl.store( + dk_ptrs, + dk.to(DK_ptr.dtype.element_ty), + mask=mask_n[:, None] & mask_d_qk[None, :], + ) + + dv_base = DV_ptr + off_b * sDV_b + off_h * sDV_h + dv_ptrs = dv_base + offs_n[:, None] * sDV_s + offs_d_v[None, :] * sDV_d + tl.store( + dv_ptrs, + dv.to(DV_ptr.dtype.element_ty), + mask=mask_n[:, None] & mask_d_v[None, :], + ) + + +# --------------------------------------------------------------------------- +# Decode forward kernel (absorbed up-projection over compressed KV cache) +# --------------------------------------------------------------------------- + + +def _mla_decode_fwd_configs(): + # Decode tiles are smaller than prefill because the effective K dim is + # ``kv_lora_rank`` (e.g. 512) which is much larger than a normal head dim, + # so SMEM pressure is high. Triton's autotune will silently prune configs + # that exceed the SM80 SMEM budget. 
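+    # For scale: one 64 x 512 bf16 c_kv tile alone is 64 KiB of shared memory,
+    # before Q_nope_abs, k_rope, or multi-stage pipelining are accounted for.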
+ block_m = [16, 32, 64] + block_n = [16, 32, 64] + num_warps = [4, 8] + num_stages = [2, 3] + configs = [] + for m, n, w, s in itertools.product(block_m, block_n, num_warps, num_stages): + configs.append( + triton.Config( + {"BLOCK_M": m, "BLOCK_N": n}, + num_warps=w, + num_stages=s, + ) + ) + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs + + +@triton.autotune( + configs=_mla_decode_fwd_configs(), + key=["S_q", "S_kv", "R", "R_rope", "IS_CAUSAL"], +) +@triton.jit +def _mla_decode_attn_fwd( + QN_ptr, # (B, H, S_q, R) Q_nope_abs (Q_nope already multiplied by W_uk^T) + QR_ptr, # (B, H, S_q, R_rope) Q_rope + CKV_ptr, # (B, S_kv, R) compressed KV cache (shared across heads) + KR_ptr, # (B, S_kv, R_rope) decoupled rope key (shared across heads) + O_ptr, # (B, H, S_q, R) O_inter (caller applies W_uv) + softmax_scale, + B, + H, + S_q, + S_kv, + R, + R_rope, + # Q_nope_abs strides + sQN_b, + sQN_h, + sQN_s, + sQN_r: tl.constexpr, + # Q_rope strides + sQR_b, + sQR_h, + sQR_s, + sQR_r: tl.constexpr, + # c_kv strides (no H) + sCKV_b, + sCKV_s, + sCKV_r: tl.constexpr, + # k_rope strides (no H) + sKR_b, + sKR_s, + sKR_r: tl.constexpr, + # O_inter strides + sO_b, + sO_h, + sO_s, + sO_r: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL_R: tl.constexpr, # next_pow2(R) + BLOCK_DMODEL_RR: tl.constexpr, # next_pow2(R_rope) +): + """Absorbed MLA decode forward — one BLOCK_M tile of Q rows for one (b, h). + + Computes ``score = Q_nope_abs @ c_kv^T + Q_rope @ k_rope^T`` (scaled), + softmax, then ``O_inter = P @ c_kv``. ``c_kv`` is reused as both K (for the + nope-side score) and V; per-head ``K_nope`` / ``V`` are never materialized. + """ + pid_m = tl.program_id(0) + pid_bh = tl.program_id(1) + tl.assume(pid_m >= 0) + tl.assume(pid_bh >= 0) + + off_b = pid_bh // H + off_h = pid_bh % H + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n_init = tl.arange(0, BLOCK_N) + offs_r = tl.arange(0, BLOCK_DMODEL_R) + offs_rr = tl.arange(0, BLOCK_DMODEL_RR) + + mask_m = offs_m < S_q + mask_r = offs_r < R + mask_rr = offs_rr < R_rope + + # Load Q_nope_abs and Q_rope tiles once. + qn_base = QN_ptr + off_b * sQN_b + off_h * sQN_h + qn_ptrs = qn_base + offs_m[:, None] * sQN_s + offs_r[None, :] * sQN_r + qn = tl.load(qn_ptrs, mask=mask_m[:, None] & mask_r[None, :], other=0.0) + + qr_base = QR_ptr + off_b * sQR_b + off_h * sQR_h + qr_ptrs = qr_base + offs_m[:, None] * sQR_s + offs_rr[None, :] * sQR_r + qr = tl.load(qr_ptrs, mask=mask_m[:, None] & mask_rr[None, :], other=0.0) + + ckv_base = CKV_ptr + off_b * sCKV_b + kr_base = KR_ptr + off_b * sKR_b + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - 1.0e6 + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_R], dtype=tl.float32) + + if IS_CAUSAL: + hi = tl.minimum((pid_m + 1) * BLOCK_M + (S_kv - S_q), S_kv) + else: + hi = S_kv + + for start_n in range(0, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + offs_n = start_n + offs_n_init + mask_n = offs_n < S_kv + + # c_kv tile [BLOCK_N, BLOCK_DMODEL_R] — used for both the nope-side + # score and as the value tile. 
+ ckv_ptrs = ckv_base + offs_n[:, None] * sCKV_s + offs_r[None, :] * sCKV_r + ckv = tl.load(ckv_ptrs, mask=mask_n[:, None] & mask_r[None, :], other=0.0) + + # k_rope tile [BLOCK_N, BLOCK_DMODEL_RR] + kr_ptrs = kr_base + offs_n[:, None] * sKR_s + offs_rr[None, :] * sKR_r + kr = tl.load(kr_ptrs, mask=mask_n[:, None] & mask_rr[None, :], other=0.0) + + # Scores: nope side is Q_nope_abs @ c_kv^T, rope side is Q_rope @ k_rope^T. + score_nope = tl.dot(qn, tl.trans(ckv)) + score_rope = tl.dot(qr, tl.trans(kr)) + qk = (score_nope + score_rope) * softmax_scale + + qk = tl.where(mask_n[None, :], qk, float("-inf")) + if IS_CAUSAL: + causal_mask = offs_n[None, :] <= (offs_m[:, None] + (S_kv - S_q)) + qk = tl.where(causal_mask, qk, float("-inf")) + + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.exp(m_i - m_ij) + p = tl.exp(qk - m_ij[:, None]) + l_i = l_i * alpha + tl.sum(p, 1) + acc = acc * alpha[:, None] + # Reuse the c_kv tile as V: O_inter += P @ c_kv. + acc = tl.dot(p.to(ckv.dtype), ckv, acc) + + m_i = m_ij + + acc = acc / l_i[:, None] + o_base = O_ptr + off_b * sO_b + off_h * sO_h + o_ptrs = o_base + offs_m[:, None] * sO_s + offs_r[None, :] * sO_r + tl.store( + o_ptrs, + acc.to(O_ptr.dtype.element_ty), + mask=mask_m[:, None] & mask_r[None, :], + ) + # LSE is intentionally not saved here — v1 has no analytical decode + # backward. If a future change adds one, re-introduce a fp32 LSE buffer + # and ``m_i + tl.log(l_i)`` store at this point. diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 4104820a1c..5a0a72c47a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -2175,3 +2175,42 @@ def forward( return output[0].view(*output[0].shape[:-2], -1), output[1] # ...hd -> ...(hd) return output.view(*output.shape[:-2], -1) + + +class MLATritonAttention(torch.nn.Module): + """Triton MLA backend wrapping :func:`transformer_engine.pytorch.triton.mla.mla_attention`. + + Opt-in path for SM80 MLA-shaped attention (``head_dim_qk != head_dim_v``). + Activated by setting ``NVTE_MLA_TRITON=1``; see + :class:`DotProductAttention` for the dispatch conditions. Inputs are + expected to already be in ``[B, S, H, D]`` (``bshd``) or + ``[S, B, H, D]`` (``sbhd``) layout — ``thd`` and ``no_mask``-padding + variants fall through to the regular cascade. + """ + + def __init__(self, softmax_scale: float, attention_dropout: float = 0.0) -> None: + super().__init__() + self.softmax_scale = softmax_scale + if attention_dropout != 0.0: + raise ValueError("MLATritonAttention does not support attention_dropout > 0.") + + def forward( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + qkv_format: str, + is_causal: bool, + ) -> torch.Tensor: + # Imported lazily so that environments without triton.mla don't pay + # the import cost when this backend is never selected. 
+ from transformer_engine.pytorch.triton.mla import mla_attention + + return mla_attention( + query_layer, + key_layer, + value_layer, + softmax_scale=self.softmax_scale, + is_causal=is_causal, + qkv_format=qkv_format, + ) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 17e9a337a4..d3809b48a9 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -60,6 +60,7 @@ UnfusedDotProductAttention, FusedAttention, FlashAttention, + MLATritonAttention, ) @@ -505,6 +506,19 @@ def __init__( return_max_logit=self.return_max_logit, ) + # Triton MLA backend for SM80 (A100). Opt-in via NVTE_MLA_TRITON=1; see + # :meth:`forward` for the dispatch conditions. Always instantiated so + # the env var can be flipped at runtime; the underlying triton kernel + # is JIT-compiled lazily on first invocation. + self.mla_triton_attention = ( + MLATritonAttention( + softmax_scale, + attention_dropout=attention_dropout, + ) + if attention_dropout == 0.0 + else None + ) + def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument """ Temporarily remove core_attention._extra_state as a missing key @@ -1500,6 +1514,49 @@ def forward( " disabling all backends." ) + # Optional MLA Triton fast path. Opt-in via NVTE_MLA_TRITON=1 and + # only activated when every supported precondition holds. Any + # mismatch (FP8, dropout, CP, sliding window, alibi, padding mask, + # cross-attn with top-left causal, etc.) silently falls through to + # the regular FA/Fused/Unfused cascade — defaulting to no behavior + # change when the env var is unset. + if ( + os.getenv("NVTE_MLA_TRITON", "0") == "1" + and self.mla_triton_attention is not None + and head_dim_qk != head_dim_v + and head_dim_qk in (128, 192, 256) + and head_dim_v in (64, 128) + and query_layer.dtype in (torch.bfloat16, torch.float16) + and not self.fp8 + and not context_parallel + and core_attention_bias_type == "no_bias" + and alibi_slopes is None + and inference_params is None + and self.softmax_type == "vanilla" + and qkv_format in ("bshd", "sbhd") + and attn_mask_type in (None, "no_mask", "causal", "causal_bottom_right") + and window_size in (None, (-1, -1), (-1, 0)) + and torch.cuda.is_available() + and torch.cuda.get_device_capability(query_layer.device)[0] >= 8 + ): + # Causal masking can be requested via either ``attn_mask_type`` + # (TE's bottom-right alignment when S_q != S_kv matches our + # kernel's right-aligned diagonal) or via FlashAttention-style + # ``window_size=(-1, 0)`` (left=infinite, right=0 → causal). + # Both paths must map to the kernel's ``is_causal=True``. 
+ is_causal_dispatch = attn_mask_type in ( + "causal", + "causal_bottom_right", + ) or window_size == (-1, 0) + self.logger.info("Running with MLATriton backend (NVTE_MLA_TRITON=1)") + return self.mla_triton_attention( + query_layer, + key_layer, + value_layer, + qkv_format=qkv_format, + is_causal=is_causal_dispatch, + ) + # run attention softmax_offset = ( self.softmax_offset.reshape(1, -1, 1, 1).to(torch.float32) diff --git a/transformer_engine/pytorch/triton/__init__.py b/transformer_engine/pytorch/triton/__init__.py index 6d3141253d..8ec7af220b 100644 --- a/transformer_engine/pytorch/triton/__init__.py +++ b/transformer_engine/pytorch/triton/__init__.py @@ -4,3 +4,4 @@ """PyTorch wrappers for Triton kernels.""" from transformer_engine.pytorch.triton import mhc +from transformer_engine.pytorch.triton import mla diff --git a/transformer_engine/pytorch/triton/mla.py b/transformer_engine/pytorch/triton/mla.py new file mode 100644 index 0000000000..df73f06e6b --- /dev/null +++ b/transformer_engine/pytorch/triton/mla.py @@ -0,0 +1,575 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""PyTorch wrapper for the Triton MLA attention kernels. + +Public entrypoint: :func:`mla_attention`. Forward and backward both run on +Triton kernels (FA-2 style). The pure-PyTorch :func:`mla_attention_ref` is +kept as the test reference. + +The kernels do NOT use ``tl.atomic_add`` (forward owns one O slice per program; +backward uses two passes — dQ owned by Q-tile programs, dK/dV owned by K-tile +programs). Results are deterministic, so no ``NVTE_ALLOW_NONDETERMINISTIC_ALGO`` +gate is needed. +""" + +from typing import Optional + +import torch +import triton + +from transformer_engine.common.triton.mla import ( + _mla_attn_fwd, + _mla_attn_bwd_preprocess, + _mla_attn_bwd_dq, + _mla_attn_bwd_dkv, + _mla_decode_attn_fwd, +) + + +_SUPPORTED_QKV_FORMATS = ("bshd", "bhsd", "sbhd") + + +def _user_to_bhsd(t: torch.Tensor, qkv_format: str) -> torch.Tensor: + """User-layout tensor -> contiguous BHSD.""" + if qkv_format == "bhsd": + return t.contiguous() if not t.is_contiguous() else t + if qkv_format == "bshd": + # [B, S, H, D] -> [B, H, S, D] + return t.transpose(1, 2).contiguous() + if qkv_format == "sbhd": + # [S, B, H, D] -> [B, H, S, D] + return t.permute(1, 2, 0, 3).contiguous() + raise ValueError( + f"Unsupported qkv_format: {qkv_format!r}. Expected one of {_SUPPORTED_QKV_FORMATS}." + ) + + +def _bhsd_to_user(t: torch.Tensor, qkv_format: str) -> torch.Tensor: + """BHSD tensor -> contiguous user layout.""" + if qkv_format == "bhsd": + return t + if qkv_format == "bshd": + # [B, H, S, D] -> [B, S, H, D] + return t.transpose(1, 2).contiguous() + if qkv_format == "sbhd": + # [B, H, S, D] -> [S, B, H, D] + return t.permute(2, 0, 1, 3).contiguous() + raise ValueError( + f"Unsupported qkv_format: {qkv_format!r}. Expected one of {_SUPPORTED_QKV_FORMATS}." + ) + + +def mla_attention_ref( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale: Optional[float] = None, + is_causal: bool = False, +) -> torch.Tensor: + """Pure-PyTorch reference MLA attention in BHSD layout. + + Supports ``head_dim_qk != head_dim_v``. Internal compute is fp32; output is + cast back to the input dtype. Right-aligned causal mask matches the kernel: + row ``i`` attends to keys ``j <= i + (S_kv - S_q)``. 
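+
+    For example, with ``S_q = 2`` and ``S_kv = 4``, row 0 attends to keys 0..2
+    and row 1 to keys 0..3 (the last query row always sees the full key
+    sequence).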
+ """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + in_dtype = q.dtype + q32 = q.float() + k32 = k.float() + v32 = v.float() + + s = torch.matmul(q32, k32.transpose(-1, -2)) * softmax_scale + if is_causal: + s_q, s_kv = s.shape[-2], s.shape[-1] + row = torch.arange(s_q, device=s.device).unsqueeze(-1) + col = torch.arange(s_kv, device=s.device).unsqueeze(0) + mask = col > (row + (s_kv - s_q)) + s = s.masked_fill(mask, float("-inf")) + p = torch.softmax(s, dim=-1) + out = torch.matmul(p, v32) + return out.to(in_dtype) + + +def _launch_mla_fwd(q_bhsd, k_bhsd, v_bhsd, softmax_scale, is_causal): + B, H, S_q, D_qk = q_bhsd.shape + S_kv = k_bhsd.shape[2] + D_v = v_bhsd.shape[3] + + o = torch.empty((B, H, S_q, D_v), device=q_bhsd.device, dtype=q_bhsd.dtype) + lse = torch.empty((B, H, S_q), device=q_bhsd.device, dtype=torch.float32) + + block_dmodel_qk = max(16, triton.next_power_of_2(D_qk)) + block_dmodel_v = max(16, triton.next_power_of_2(D_v)) + + grid = lambda meta: (triton.cdiv(S_q, meta["BLOCK_M"]), B * H) + + _mla_attn_fwd[grid]( + q_bhsd, + k_bhsd, + v_bhsd, + o, + lse, + float(softmax_scale), + B, + H, + S_q, + S_kv, + D_qk, + D_v, + # Q strides + q_bhsd.stride(0), + q_bhsd.stride(1), + q_bhsd.stride(2), + 1, + # K strides + k_bhsd.stride(0), + k_bhsd.stride(1), + k_bhsd.stride(2), + 1, + # V strides + v_bhsd.stride(0), + v_bhsd.stride(1), + v_bhsd.stride(2), + 1, + # O strides + o.stride(0), + o.stride(1), + o.stride(2), + 1, + # LSE strides + lse.stride(0), + lse.stride(1), + 1, + IS_CAUSAL=is_causal, + BLOCK_DMODEL_QK=block_dmodel_qk, + BLOCK_DMODEL_V=block_dmodel_v, + ) + return o, lse + + +# Preprocess kernel uses a fixed BLOCK_M (it's I/O-bound and the choice barely +# matters); the dQ / dKV kernels autotune and pick their own BLOCK_M / BLOCK_N. 
+_BWD_PREPROCESS_BLOCK_M = 128 + + +def _launch_mla_bwd(q_bhsd, k_bhsd, v_bhsd, o_bhsd, lse, do_bhsd, softmax_scale, is_causal): + """Run the three backward kernels and return ``(dQ, dK, dV)`` in BHSD.""" + B, H, S_q, D_qk = q_bhsd.shape + S_kv = k_bhsd.shape[2] + D_v = v_bhsd.shape[3] + + block_dmodel_qk = max(16, triton.next_power_of_2(D_qk)) + block_dmodel_v = max(16, triton.next_power_of_2(D_v)) + + delta = torch.empty((B, H, S_q), device=q_bhsd.device, dtype=torch.float32) + dq = torch.empty_like(q_bhsd) + dk = torch.empty_like(k_bhsd) + dv = torch.empty_like(v_bhsd) + + # 1) Delta = rowsum(O * dO) + grid_pre = (triton.cdiv(S_q, _BWD_PREPROCESS_BLOCK_M), B * H) + _mla_attn_bwd_preprocess[grid_pre]( + o_bhsd, + do_bhsd, + delta, + B, + H, + S_q, + D_v, + # O strides + o_bhsd.stride(0), + o_bhsd.stride(1), + o_bhsd.stride(2), + 1, + # dO strides + do_bhsd.stride(0), + do_bhsd.stride(1), + do_bhsd.stride(2), + 1, + # Delta strides + delta.stride(0), + delta.stride(1), + 1, + BLOCK_M=_BWD_PREPROCESS_BLOCK_M, + BLOCK_DMODEL_V=block_dmodel_v, + ) + + common_args = ( + q_bhsd, + k_bhsd, + v_bhsd, + do_bhsd, + lse, + delta, + ) + common_shape_args = (B, H, S_q, S_kv, D_qk, D_v) + common_strides = ( + # Q + q_bhsd.stride(0), + q_bhsd.stride(1), + q_bhsd.stride(2), + 1, + # K + k_bhsd.stride(0), + k_bhsd.stride(1), + k_bhsd.stride(2), + 1, + # V + v_bhsd.stride(0), + v_bhsd.stride(1), + v_bhsd.stride(2), + 1, + # dO + do_bhsd.stride(0), + do_bhsd.stride(1), + do_bhsd.stride(2), + 1, + # LSE + lse.stride(0), + lse.stride(1), + 1, + # Delta + delta.stride(0), + delta.stride(1), + 1, + ) + + # 2) dQ + grid_dq = lambda meta: (triton.cdiv(S_q, meta["BLOCK_M"]), B * H) + _mla_attn_bwd_dq[grid_dq]( + *common_args, + dq, + float(softmax_scale), + *common_shape_args, + *common_strides, + # dQ strides + dq.stride(0), + dq.stride(1), + dq.stride(2), + 1, + IS_CAUSAL=is_causal, + BLOCK_DMODEL_QK=block_dmodel_qk, + BLOCK_DMODEL_V=block_dmodel_v, + ) + + # 3) dK, dV + grid_dkv = lambda meta: (triton.cdiv(S_kv, meta["BLOCK_N"]), B * H) + _mla_attn_bwd_dkv[grid_dkv]( + *common_args, + dk, + dv, + float(softmax_scale), + *common_shape_args, + *common_strides, + # dK strides + dk.stride(0), + dk.stride(1), + dk.stride(2), + 1, + # dV strides + dv.stride(0), + dv.stride(1), + dv.stride(2), + 1, + IS_CAUSAL=is_causal, + BLOCK_DMODEL_QK=block_dmodel_qk, + BLOCK_DMODEL_V=block_dmodel_v, + ) + + return dq, dk, dv + + +class MLAttentionFn(torch.autograd.Function): + """Forward and backward via Triton kernels (FA-2 style).""" + + @staticmethod + def forward(ctx, q, k, v, softmax_scale, is_causal, qkv_format): + if qkv_format not in _SUPPORTED_QKV_FORMATS: + raise ValueError( + f"qkv_format must be one of {_SUPPORTED_QKV_FORMATS}, got {qkv_format!r}" + ) + if q.dtype not in (torch.float16, torch.bfloat16): + raise ValueError(f"mla_attention requires fp16 or bf16 inputs, got {q.dtype}") + if not (q.dtype == k.dtype == v.dtype): + raise ValueError("q, k, v must share the same dtype") + if not q.is_cuda: + raise ValueError("mla_attention requires CUDA tensors") + + q_bhsd = _user_to_bhsd(q, qkv_format) + k_bhsd = _user_to_bhsd(k, qkv_format) + v_bhsd = _user_to_bhsd(v, qkv_format) + + if q_bhsd.shape[3] != k_bhsd.shape[3]: + raise ValueError( + f"q.head_dim and k.head_dim must match (got {q_bhsd.shape[3]} vs {k_bhsd.shape[3]})" + ) + if q_bhsd.shape[:2] != k_bhsd.shape[:2] or q_bhsd.shape[:2] != v_bhsd.shape[:2]: + raise ValueError("q, k, v must share batch and head dimensions") + if k_bhsd.shape[2] != v_bhsd.shape[2]: + 
raise ValueError("k.seq_len and v.seq_len must match") + + o_bhsd, lse = _launch_mla_fwd(q_bhsd, k_bhsd, v_bhsd, softmax_scale, is_causal) + + ctx.save_for_backward(q_bhsd, k_bhsd, v_bhsd, o_bhsd, lse) + ctx.softmax_scale = softmax_scale + ctx.is_causal = is_causal + ctx.qkv_format = qkv_format + + return _bhsd_to_user(o_bhsd, qkv_format) + + @staticmethod + def backward(ctx, grad_o): + q_bhsd, k_bhsd, v_bhsd, o_bhsd, lse = ctx.saved_tensors + grad_o_bhsd = _user_to_bhsd(grad_o, ctx.qkv_format) + + dq_bhsd, dk_bhsd, dv_bhsd = _launch_mla_bwd( + q_bhsd, + k_bhsd, + v_bhsd, + o_bhsd, + lse, + grad_o_bhsd, + ctx.softmax_scale, + ctx.is_causal, + ) + + grad_q = _bhsd_to_user(dq_bhsd, ctx.qkv_format) + grad_k = _bhsd_to_user(dk_bhsd, ctx.qkv_format) + grad_v = _bhsd_to_user(dv_bhsd, ctx.qkv_format) + return grad_q, grad_k, grad_v, None, None, None + + +def mla_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + softmax_scale: Optional[float] = None, + is_causal: bool = False, + qkv_format: str = "bshd", +) -> torch.Tensor: + """Triton MLA attention (FA-2 style forward and backward) for SM80+. + + Supports ``head_dim_qk != head_dim_v`` (DeepSeek-V2/V3 style). Both forward + and backward run as Triton kernels: the forward saves an fp32 ``LSE``, which + the backward consumes to recompute softmax probabilities without extra + memory. + + Parameters + ---------- + q, k, v + Tensors in ``qkv_format`` layout. ``q`` and ``k`` must share the last + dim (head_dim_qk); ``v`` may have a different last dim (head_dim_v). + Dtype must be fp16 or bf16. Batch and head dims must match across the + three tensors; ``k`` and ``v`` must share seq_len. + softmax_scale + Multiplier applied to ``Q @ K^T`` before softmax. Defaults to + ``1 / sqrt(head_dim_qk)``. + is_causal + If True, applies a right-aligned causal mask: row ``i`` attends to + keys ``j <= i + (S_kv - S_q)``. With ``S_q == S_kv`` this is the + standard causal self-attention mask. + qkv_format + ``"bshd"`` (default, matches FlashAttention), ``"bhsd"`` (matches + ``F.scaled_dot_product_attention``), or ``"sbhd"`` (megatron-style). + + Returns + ------- + torch.Tensor + Attention output in ``qkv_format`` layout, dtype matching the inputs. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + return MLAttentionFn.apply(q, k, v, softmax_scale, is_causal, qkv_format) + + +# --------------------------------------------------------------------------- +# Decode (absorbed up-projection over compressed KV cache) +# --------------------------------------------------------------------------- + + +def mla_decode_attention_ref( + q_nope_abs: torch.Tensor, + q_rope: torch.Tensor, + c_kv: torch.Tensor, + k_rope: torch.Tensor, + softmax_scale: float, + is_causal: bool = False, +) -> torch.Tensor: + """Pure-PyTorch reference for absorbed MLA decode attention. + + Inputs (BHSD layout for Q-side, BSR layout for compressed cache): + + - ``q_nope_abs`` ``[B, H, S_q, R]`` — Q's nope side already multiplied by + ``W_uk^T``. ``R`` is ``kv_lora_rank``. + - ``q_rope`` ``[B, H, S_q, R_rope]`` — Q's rope side. + - ``c_kv`` ``[B, S_kv, R]`` — compressed KV cache, shared across heads. + - ``k_rope`` ``[B, S_kv, R_rope]`` — decoupled rope keys, shared across heads. + - ``softmax_scale`` — typically ``1 / sqrt(head_dim_qk_orig)`` where + ``head_dim_qk_orig = q_nope_dim + R_rope`` (the *un-absorbed* head dim). + + Returns ``o_inter`` ``[B, H, S_q, R]``. The caller multiplies by ``W_uv`` + (per head) to obtain the final attention output. 
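+
+    For the DeepSeek-V2-style dims used in the tests (``q_nope_dim = 128``,
+    ``R_rope = 64``) that is ``1 / sqrt(192)``, the same value the prefill
+    path would pick by default for ``head_dim_qk = 192``.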
+ """ + in_dtype = q_nope_abs.dtype + qn = q_nope_abs.float() + qr = q_rope.float() + ck = c_kv.float() + kr = k_rope.float() + + score_nope = torch.einsum("bhmr,bnr->bhmn", qn, ck) + score_rope = torch.einsum("bhmr,bnr->bhmn", qr, kr) + s = (score_nope + score_rope) * softmax_scale + + if is_causal: + s_q, s_kv = s.shape[-2], s.shape[-1] + row = torch.arange(s_q, device=s.device).unsqueeze(-1) + col = torch.arange(s_kv, device=s.device).unsqueeze(0) + mask = col > (row + (s_kv - s_q)) + s = s.masked_fill(mask, float("-inf")) + + p = torch.softmax(s, dim=-1) + o_inter = torch.einsum("bhmn,bnr->bhmr", p, ck) + return o_inter.to(in_dtype) + + +def _launch_mla_decode_fwd(q_nope_abs, q_rope, c_kv, k_rope, softmax_scale, is_causal): + B, H, S_q, R = q_nope_abs.shape + S_kv = c_kv.shape[1] + R_rope = q_rope.shape[3] + + o_inter = torch.empty((B, H, S_q, R), device=q_nope_abs.device, dtype=q_nope_abs.dtype) + + block_dmodel_r = max(16, triton.next_power_of_2(R)) + block_dmodel_rr = max(16, triton.next_power_of_2(R_rope)) + + grid = lambda meta: (triton.cdiv(S_q, meta["BLOCK_M"]), B * H) + + _mla_decode_attn_fwd[grid]( + q_nope_abs, + q_rope, + c_kv, + k_rope, + o_inter, + float(softmax_scale), + B, + H, + S_q, + S_kv, + R, + R_rope, + # Q_nope_abs strides + q_nope_abs.stride(0), + q_nope_abs.stride(1), + q_nope_abs.stride(2), + 1, + # Q_rope strides + q_rope.stride(0), + q_rope.stride(1), + q_rope.stride(2), + 1, + # c_kv strides (no H) + c_kv.stride(0), + c_kv.stride(1), + 1, + # k_rope strides (no H) + k_rope.stride(0), + k_rope.stride(1), + 1, + # O_inter strides + o_inter.stride(0), + o_inter.stride(1), + o_inter.stride(2), + 1, + IS_CAUSAL=is_causal, + BLOCK_DMODEL_R=block_dmodel_r, + BLOCK_DMODEL_RR=block_dmodel_rr, + ) + return o_inter + + +def mla_decode_attention( + q_nope_abs: torch.Tensor, + q_rope: torch.Tensor, + c_kv: torch.Tensor, + k_rope: torch.Tensor, + *, + softmax_scale: float, + is_causal: bool = False, +) -> torch.Tensor: + """Triton MLA decode forward over a compressed KV cache (SM80+). + + Implements the FlashMLA-style absorbed up-projection: ``c_kv`` plays the + role of both K (nope side) and V, so per-head K/V are never materialized + inside the kernel. The caller is expected to apply ``W_uv`` to the + returned ``O_inter`` to obtain the final attention output. + + Parameters + ---------- + q_nope_abs : ``[B, H, S_q, R]`` + Q's nope side after absorbing ``W_uk^T`` (so its last-dim is + ``kv_lora_rank``, not the original Q nope dim). + q_rope : ``[B, H, S_q, R_rope]`` + Q's rope side. + c_kv : ``[B, S_kv, R]`` + Compressed KV cache, shared across heads. + k_rope : ``[B, S_kv, R_rope]`` + Decoupled rope keys, shared across heads. + softmax_scale : float + Required. Typically ``1 / sqrt(q_nope_dim + R_rope)`` — i.e. the + original head_dim_qk before absorption. **No default**: the absorbed + Q's last dim (``R``) is *not* the right denominator, so the kernel + will not guess. + is_causal : bool + Right-aligned causal: row ``i`` attends to keys ``j <= i + (S_kv - S_q)``. + + Returns + ------- + torch.Tensor + ``O_inter`` of shape ``[B, H, S_q, R]`` in the input dtype. Apply + ``W_uv`` outside the kernel for the final attention output. 
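+
+    An illustrative end-to-end sketch of the absorb/apply steps left to the
+    caller (``w_uk`` and ``w_uv`` are hypothetical per-head weights of shape
+    ``[H, q_nope_dim, R]`` and ``[H, v_head_dim, R]``; only the einsum pattern
+    matters)::
+
+        q_nope_abs = torch.einsum("bhsd,hdr->bhsr", q_nope, w_uk)
+        softmax_scale = (q_nope.shape[-1] + q_rope.shape[-1]) ** -0.5
+        o_inter = mla_decode_attention(q_nope_abs, q_rope, c_kv, k_rope,
+                                       softmax_scale=softmax_scale)
+        out = torch.einsum("bhsr,hdr->bhsd", o_inter, w_uv)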
+ """ + if q_nope_abs.dtype not in (torch.float16, torch.bfloat16): + raise ValueError( + f"mla_decode_attention requires fp16 or bf16 inputs, got {q_nope_abs.dtype}" + ) + if not (q_nope_abs.dtype == q_rope.dtype == c_kv.dtype == k_rope.dtype): + raise ValueError("q_nope_abs, q_rope, c_kv, k_rope must share dtype") + if not q_nope_abs.is_cuda: + raise ValueError("mla_decode_attention requires CUDA tensors") + # Forward-only in v1: the kernel is launched directly (no autograd.Function + # wrapper), so a tensor that needs gradients would silently lose them at + # the kernel boundary. Refuse loudly instead. + if any(t.requires_grad for t in (q_nope_abs, q_rope, c_kv, k_rope)): + raise NotImplementedError( + "mla_decode_attention is forward-only in v1; backward through the" + " absorbed-projection decode kernel is not implemented. Detach" + " inputs (or use torch.no_grad()) before calling." + ) + + if q_nope_abs.shape[-1] != c_kv.shape[-1]: + raise ValueError( + "q_nope_abs.last_dim and c_kv.last_dim must match (kv_lora_rank);" + f" got {q_nope_abs.shape[-1]} vs {c_kv.shape[-1]}" + ) + if q_rope.shape[-1] != k_rope.shape[-1]: + raise ValueError( + "q_rope.last_dim and k_rope.last_dim must match (rope_dim);" + f" got {q_rope.shape[-1]} vs {k_rope.shape[-1]}" + ) + if q_nope_abs.shape[:3] != q_rope.shape[:3]: + raise ValueError("q_nope_abs and q_rope must share (B, H, S_q)") + if c_kv.shape[:2] != k_rope.shape[:2]: + raise ValueError("c_kv and k_rope must share (B, S_kv)") + if q_nope_abs.shape[0] != c_kv.shape[0]: + raise ValueError("Q-side and KV-cache batch dims must match") + + qn = q_nope_abs.contiguous() if not q_nope_abs.is_contiguous() else q_nope_abs + qr = q_rope.contiguous() if not q_rope.is_contiguous() else q_rope + ck = c_kv.contiguous() if not c_kv.is_contiguous() else c_kv + kr = k_rope.contiguous() if not k_rope.is_contiguous() else k_rope + + return _launch_mla_decode_fwd(qn, qr, ck, kr, softmax_scale, is_causal)