diff --git a/custom_ops/gpu_ops/helper.h b/custom_ops/gpu_ops/helper.h index 7a0a75283c3..27d81d867c9 100644 --- a/custom_ops/gpu_ops/helper.h +++ b/custom_ops/gpu_ops/helper.h @@ -715,7 +715,8 @@ inline const char *getEnvVar(const char *varName) { inline bool checkAttentionBackend() { const char *backend = getEnvVar("FD_ATTENTION_BACKEND"); if (backend && (std::strcmp(backend, "MLA_ATTN") == 0 || - std::strcmp(backend, "DSA_ATTN") == 0)) { + std::strcmp(backend, "DSA_ATTN") == 0 || + std::strcmp(backend, "TRITON_MLA_ATTN") == 0)) { return true; } return false; diff --git a/fastdeploy/config.py b/fastdeploy/config.py index ad02ba8d333..fd2f29ad745 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1618,7 +1618,7 @@ def __init__(self, args): self.write_policy = "write_through_selective" self.write_through_threshold = 2 self.num_cpu_blocks = None - self.use_mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN" + self.use_mla_cache = envs.FD_ATTENTION_BACKEND in ("MLA_ATTN", "TRITON_MLA_ATTN") for key, value in args.items(): if hasattr(self, key): diff --git a/fastdeploy/model_executor/layers/attention/__init__.py b/fastdeploy/model_executor/layers/attention/__init__.py index 7efc3259fbc..bd214c0b370 100644 --- a/fastdeploy/model_executor/layers/attention/__init__.py +++ b/fastdeploy/model_executor/layers/attention/__init__.py @@ -23,6 +23,7 @@ from .mla_attention_backend import MLAAttentionBackend from .moba_attention_backend import PlasAttentionBackend from .native_paddle_backend import PaddleNativeAttnBackend +from .triton_mla_attention_backend import TritonMLAAttentionBackend __all__ = [ "AttentionBackend", @@ -36,4 +37,5 @@ "Attention", "PlasAttentionBackend", "FlashMaskAttentionBackend", + "TritonMLAAttentionBackend", ] diff --git a/fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py new file mode 100644 index 00000000000..726f11f0995 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py @@ -0,0 +1,368 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Triton-based MLA Attention Backend for FastDeploy. +# Uses triton kernels for KV cache write and decode attention, +# and flash_attn_unpadded for extend (prefill) attention. +""" + +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Tuple + +import paddle +from paddle.nn.functional.flash_attention import flash_attn_unpadded +from paddleformers.utils.log import logger + +try: + from paddle.nn.functional.flash_attention import flash_attention_v3_varlen +except Exception as e: + logger.debug(f"flash_attention_v3_varlen not available: {e}") + flash_attention_v3_varlen = None + +from fastdeploy.model_executor.layers.attention.ops import ( + get_block_shape_and_split_kv_block, +) +from fastdeploy.model_executor.layers.attention.triton_ops.decode_attention import ( + compute_num_kv_splits, + decode_attention_fwd, +) +from fastdeploy.model_executor.layers.attention.triton_ops.mla_cache_kernel import ( + mla_write_cache_triton, +) +from fastdeploy.model_executor.layers.attention.triton_ops.unified_extend_attention import ( + build_kv_indices_from_block_tables, +) + +if TYPE_CHECKING: + from fastdeploy.model_executor.forward_meta import ForwardMeta + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.layers.attention.attention import Attention +from fastdeploy.model_executor.layers.attention.base_attention_backend import ( + AttentionBackend, + AttentionMetadata, +) +from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +@dataclass +class TritonMLAAttentionMetadata(AttentionMetadata): + _dtype: str = "bfloat16" + block_tables: Optional[paddle.Tensor] = None + max_enc_len_this_time: Optional[paddle.Tensor] = None + max_dec_len_this_time: Optional[paddle.Tensor] = None + max_kv_len_this_time: Optional[paddle.Tensor] = None + max_seqlen_k: int = 0 + # Pre-computed decode indices (CUDAGraph compatible) + kv_indptr: Optional[paddle.Tensor] = None + kv_indices: Optional[paddle.Tensor] = None + num_kv_splits: Optional[paddle.Tensor] = None + decode_bs: int = 0 + + +class TritonMLAAttentionBackend(AttentionBackend): + """ + Triton-based MLA Attention Backend. + Uses triton kernels for KV cache write and decode attention. + """ + + __infer_dynamic_dims_fields__ = ["attention_metadata"] + attention_metadata: TritonMLAAttentionMetadata + flash_attn_func: callable = None + + def __init__( + self, + fd_config: FDConfig, + kv_num_heads: int, + num_heads: int, + head_dim: int, + encoder_block_shape_q: int = -1, + decoder_block_shape_q: int = -1, + ) -> None: + super().__init__() + self.attention_metadata: TritonMLAAttentionMetadata = None + + self.block_size: int = fd_config.cache_config.block_size + self.max_seq_len: int = fd_config.model_config.max_model_len + self.causal: bool = getattr(fd_config.model_config, "causal", True) + + self.num_heads: int = num_heads + self.head_dim: int = head_dim + self.num_layers: int = fd_config.model_config.num_hidden_layers + + self.kv_lora_rank: int = fd_config.model_config.kv_lora_rank + self.qk_rope_head_dim: int = fd_config.model_config.qk_rope_head_dim + self.qk_head_dim: int = fd_config.model_config.qk_nope_head_dim + fd_config.model_config.qk_rope_head_dim + self.attn_softmax_scale: float = self.qk_head_dim**-0.5 + self.rope_scaling = getattr(fd_config.model_config, "rope_scaling", None) + if self.rope_scaling and "factor" in self.rope_scaling: + mscale_all_dim = fd_config.model_config.rope_scaling.get("mscale_all_dim", False) + scaling_factor = fd_config.model_config.rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.attn_softmax_scale = self.attn_softmax_scale * mscale * mscale + + self.max_kv_splits: int = 32 + + self.rank, self.device_id = init_rank_and_device_id(fd_config) + self.useless_tensor = paddle.zeros([1], dtype="int32") + + # Pre-allocate buffers for CUDAGraph compatibility (stable memory addresses) + self.max_num_seqs = fd_config.scheduler_config.max_num_seqs + max_blocks_per_seq = fd_config.cache_config.max_block_num_per_seq + self._kv_indptr_buf = paddle.zeros([self.max_num_seqs + 1], dtype="int32") + self._kv_indices_buf = paddle.zeros([self.max_num_seqs * max_blocks_per_seq * self.block_size], dtype="int32") + self._num_kv_splits_buf = paddle.ones([self.max_num_seqs], dtype="int32") + + # Pre-allocate decode kernel intermediate buffers for CUDAGraph address stability + Lv = fd_config.model_config.kv_lora_rank + self._attn_logits_buf = paddle.empty([self.max_num_seqs, num_heads, self.max_kv_splits, Lv], dtype="float32") + self._attn_lse_buf = paddle.empty([self.max_num_seqs, num_heads, self.max_kv_splits], dtype="float32") + self._o_buf = paddle.empty([self.max_num_seqs, num_heads, Lv], dtype=paddle.get_default_dtype()) + + if self.flash_attn_func is None: + prop = paddle.device.cuda.get_device_properties() + cc = prop.major * 10 + prop.minor + is_current_sm_supported = cc >= 90 + is_paddle_supported = any(num >= 90 for num in paddle.version.cuda_archs()) + if is_current_sm_supported and is_paddle_supported: + self.flash_attn_func = flash_attention_v3_varlen + logger.info("TritonMLAAttentionBackend: Using Flash Attention V3.") + self.flash_attn_kwargs = {"softmax_scale": self.attn_softmax_scale} + else: + self.flash_attn_func = flash_attn_unpadded + logger.info("TritonMLAAttentionBackend: Using Flash Attention V2.") + self.flash_attn_kwargs = {"scale": self.attn_softmax_scale, "training": False} + + def init_attention_metadata(self, forward_meta: ForwardMeta): + metadata = TritonMLAAttentionMetadata() + metadata._dtype = paddle.get_default_dtype() + metadata.block_tables = forward_meta.block_tables + + get_block_shape_and_split_kv_block( + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.decoder_batch_ids, + forward_meta.decoder_tile_ids_per_batch, + self.useless_tensor, + forward_meta.decoder_num_blocks_device, + forward_meta.decoder_chunk_size_device, + forward_meta.max_len_tensor_cpu, + self.useless_tensor, + self.useless_tensor, + self.useless_tensor, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, + -1, + -1, + -1, + self.block_size, + ) + metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1] + metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2] + metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[5] + # max_seqlen_k must include cached tokens for chunked prefill / prefix caching. + metadata.max_seqlen_k = max( + int(metadata.max_kv_len_this_time.item()), + int(metadata.max_enc_len_this_time.item()), + ) + + # Pre-compute decode kv_indptr/kv_indices into stable pre-allocated buffers. + # CUDAGraph requires tensors at the same memory address between capture and replay. + # + # IMPORTANT: q for decode is built by extract_decoder_token_from_q which produces + # a tensor of shape [max_num_seqs, hidden_dim] indexed by ORIGINAL batch_id (with + # garbage for non-decode batches). Therefore kv_indptr/kv_indices MUST follow the + # same full max_num_seqs layout (not mask-compressed) so that batch i in q is + # paired with batch i's KV cache indices. For non-decode batches we leave a + # zero-length range so the kernel produces a zeroed/no-op output for those slots. + seq_lens_decoder = forward_meta.seq_lens_decoder + seq_lens_this_time = forward_meta.seq_lens_this_time + max_num_seqs = seq_lens_decoder.shape[0] + decode_mask = seq_lens_decoder > 0 + decode_bs = int(decode_mask.sum().item()) + metadata.decode_bs = decode_bs + + if decode_bs > 0: + # Full-layout per-batch KV lengths: real length for decode batches, 0 otherwise. + full_seq_lens = paddle.where( + decode_mask, + seq_lens_decoder + seq_lens_this_time, + paddle.zeros_like(seq_lens_decoder), + ).cast("int32") + total_kv_len = int(paddle.sum(full_seq_lens).item()) + + build_kv_indices_from_block_tables( + forward_meta.block_tables, + full_seq_lens, + self.block_size, + max_num_seqs, + total_kv_len=total_kv_len, + kv_indptr_buf=self._kv_indptr_buf, + kv_indices_buf=self._kv_indices_buf, + ) + + # num_kv_splits in full layout: real value for decode batches, 1 otherwise + # (must be >= 1 to avoid division by zero in stage1 kernel). + compute_num_kv_splits(full_seq_lens, max_num_seqs, self.max_kv_splits, out_buf=self._num_kv_splits_buf) + # Force padded/non-decode entries to 1 (compute_num_kv_splits may return 0 for len=0) + self._num_kv_splits_buf[:] = paddle.where( + self._num_kv_splits_buf > 0, + self._num_kv_splits_buf, + paddle.ones_like(self._num_kv_splits_buf), + ) + else: + # No decode sequences: fill buffers with safe defaults + self._kv_indptr_buf[:] = 0 + self._num_kv_splits_buf[:] = 1 + + # Always use the full pre-allocated buffers (stable memory for CUDAGraph) + metadata.kv_indptr = self._kv_indptr_buf + metadata.kv_indices = self._kv_indices_buf + metadata.num_kv_splits = self._num_kv_splits_buf + + self.attention_metadata = metadata + + def get_attention_meta(self) -> AttentionMetadata: + return self.attention_metadata + + def get_kv_cache_shape( + self, + max_num_blocks: int, + kv_cache_quant_type: str = None, + ) -> Tuple[int, int, int, int]: + key_cache_shape = [max_num_blocks, 1, self.block_size, self.kv_lora_rank + self.qk_rope_head_dim] + value_cache_shape = [] + return key_cache_shape, value_cache_shape + + def forward_extend( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qkv: paddle.Tensor, + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, + layer: Attention, + forward_meta: ForwardMeta, + ) -> paddle.Tensor: + metadata = self.attention_metadata + latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None + + if latent_cache is not None and forward_meta.slot_mapping is not None: + mla_write_cache_triton(compressed_kv, k_pe, latent_cache, forward_meta.slot_mapping) + + fmha_out = self.flash_attn_func( + q, + k, + v, + forward_meta.cu_seqlens_q, + forward_meta.cu_seqlens_k, + metadata.max_enc_len_this_time, + metadata.max_seqlen_k, + causal=self.causal, + **self.flash_attn_kwargs, + )[0] + + return fmha_out + + def forward_decode( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qkv: paddle.Tensor, + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, + layer: Attention, + forward_meta: ForwardMeta, + ) -> paddle.Tensor: + metadata = self.attention_metadata + latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None + + if latent_cache is not None and forward_meta.slot_mapping is not None: + mla_write_cache_triton(compressed_kv, k_pe, latent_cache, forward_meta.slot_mapping) + + return self._run_decode_kernel(q, latent_cache, metadata) + + def forward_mixed( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qkv: paddle.Tensor, + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, + layer: Attention, + forward_meta: ForwardMeta, + ) -> paddle.Tensor: + if k is not None: + return self.forward_extend(q, k, v, qkv, compressed_kv, k_pe, layer, forward_meta) + + # Decode branch within mixed mode. q is in full max_num_seqs layout (built by + # extract_decoder_token_from_q upstream). kv_indptr/kv_indices are built in the + # same full layout, so non-decode batches naturally get 0-length KV ranges and + # produce no-op output, which insert_decoder_result_back then ignores. + metadata = self.attention_metadata + latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None + + if latent_cache is not None and forward_meta.slot_mapping is not None: + mla_write_cache_triton(compressed_kv, k_pe, latent_cache, forward_meta.slot_mapping) + + return self._run_decode_kernel(q, latent_cache, metadata) + + def _run_decode_kernel( + self, + q: paddle.Tensor, + latent_cache: paddle.Tensor, + metadata: TritonMLAAttentionMetadata, + ) -> paddle.Tensor: + """Run triton decode attention kernel. q must have shape [bs, num_heads * latent_dim].""" + bs = q.shape[0] + Lv = self.kv_lora_rank + latent_dim = self.kv_lora_rank + self.qk_rope_head_dim + q_reshaped = q.reshape([bs, self.num_heads, latent_dim]) + + # Use pre-allocated buffers sliced to current batch size for CUDAGraph address stability + attn_logits = self._attn_logits_buf[:bs] + attn_lse = self._attn_lse_buf[:bs] + o = self._o_buf[:bs] + + decode_attention_fwd( + q_reshaped, + latent_cache, + latent_cache[:, :, :, : self.kv_lora_rank], + o, + metadata.kv_indptr, + metadata.kv_indices, + attn_logits, + attn_lse, + metadata.num_kv_splits, + self.max_kv_splits, + self.attn_softmax_scale, + self.block_size, + ) + + return o.reshape([-1, self.num_heads * Lv]) diff --git a/fastdeploy/model_executor/layers/attention/triton_ops/__init__.py b/fastdeploy/model_executor/layers/attention/triton_ops/__init__.py index 860bc3beadd..d60e91d2921 100644 --- a/fastdeploy/model_executor/layers/attention/triton_ops/__init__.py +++ b/fastdeploy/model_executor/layers/attention/triton_ops/__init__.py @@ -17,3 +17,12 @@ # https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/triton_ops/extend_attention.py # Licensed under Apache License 2.0 """ + +from .decode_attention import compute_num_kv_splits, decode_attention_fwd # noqa: F401 +from .mla_cache_kernel import mla_write_cache_triton # noqa: F401 +from .unified_extend_attention import ( # noqa: F401 + build_kv_indices_from_block_tables, + build_unified_kv_indices, + extend_attention_fwd_unified, + triton_cumsum_with_zero_prefix, +) diff --git a/fastdeploy/model_executor/layers/attention/triton_ops/decode_attention.py b/fastdeploy/model_executor/layers/attention/triton_ops/decode_attention.py new file mode 100644 index 00000000000..2beb9a6ef1d --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/triton_ops/decode_attention.py @@ -0,0 +1,476 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from +# https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +# Licensed under Apache License 2.0 +# +# Memory-efficient split-KV attention for decoding, adapted for paged KV cache. +""" + +import paddle +import triton +import triton.language as tl + +from fastdeploy.model_executor.ops.triton_ops.triton_utils import ( + enable_compat_on_triton_kernel, +) + +_MIN_BLOCK_KV = 32 + + +@enable_compat_on_triton_kernel +@triton.jit +def _fwd_grouped_kernel_stage1( + Q, + K_Buffer, + V_Buffer, + sm_scale, + kv_indptr, + kv_indices, + Att_Out, + Att_Lse, + num_kv_splits, + stride_qbs, + stride_qh, + stride_buf_kb, + stride_buf_kh, + stride_buf_kt, + stride_buf_vb, + stride_buf_vh, + stride_buf_vt, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + q_head_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_H: tl.constexpr, + MIN_BLOCK_KV: tl.constexpr, + KV_BLOCK_SIZE: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, +): + """ + Split-KV decode attention stage 1 for grouped query attention on paged cache. + Each program handles (batch, head_group, kv_split). + """ + cur_batch = tl.program_id(0) + cur_head_id = tl.program_id(1) + cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) + split_kv_id = tl.program_id(2) + + if BLOCK_H < kv_group_num: + VALID_BLOCK_H: tl.constexpr = BLOCK_H + else: + VALID_BLOCK_H: tl.constexpr = kv_group_num + cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) + mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H + mask_h = mask_h & (cur_head < q_head_num) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + + cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) + cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx + kv_splits = tl.load(num_kv_splits + cur_batch) + + offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] + + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + mask_dpe = offs_dpe < Lk + off_qpe = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :] + + kv_len_per_split = tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) + if BLOCK_DPE > 0: + qpe = tl.load(Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0) + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_loc = tl.load( + kv_indices + cur_batch_kv_start_idx + offs_n, + mask=offs_n < split_kv_end, + other=0, + ) + # Decompose flat index into (block_id, offset_in_block) for paged cache + kv_block_id = kv_loc // KV_BLOCK_SIZE + kv_offset = kv_loc % KV_BLOCK_SIZE + + # Load K: cache shape [num_blocks, kv_heads, block_size, head_dim] + offs_buf_k = ( + kv_block_id[None, :] * stride_buf_kb + + cur_kv_head * stride_buf_kh + + kv_offset[None, :] * stride_buf_kt + + offs_d[:, None] + ) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), + other=0.0, + ) + qk = tl.dot(q, k.to(q.dtype)) + if BLOCK_DPE > 0: + offs_buf_kpe = ( + kv_block_id[None, :] * stride_buf_kb + + cur_kv_head * stride_buf_kh + + kv_offset[None, :] * stride_buf_kt + + offs_dpe[:, None] + ) + kpe = tl.load( + K_Buffer + offs_buf_kpe, + mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), + other=0.0, + ) + qk += tl.dot(qpe, kpe.to(qpe.dtype)) + qk *= sm_scale + + qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")) + + # Load V from paged cache + offs_buf_v = ( + kv_block_id[:, None] * stride_buf_vb + + cur_kv_head * stride_buf_vh + + kv_offset[:, None] * stride_buf_vt + + offs_dv[None, :] + ) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v.dtype), v) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head[:, None] * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv[None, :] + ) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum[:, None], + mask=(mask_h[:, None]) & (mask_dv[None, :]), + ) + + offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + split_kv_id * stride_mid_os) // Lv + + tl.store( + Att_Lse + offs_mid_o_1, + e_max + tl.log(e_sum), + mask=mask_h, + ) + + +@enable_compat_on_triton_kernel +@triton.jit +def _fwd_kernel_stage2( + Mid_O, + Mid_O_1, + O, + kv_indptr, + num_kv_splits, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + MAX_KV_SPLITS: tl.constexpr, + MIN_BLOCK_KV: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, +): + """ + Stage 2: reduce across KV splits to produce final output. + """ + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(kv_indptr + cur_batch) + kv_splits = tl.load(num_kv_splits + cur_batch) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv + kv_len_per_split = tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV + + for split_kv_id in range(0, MAX_KV_SPLITS): + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0) + tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * stride_mid_os // Lv) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + # Guard against e_sum==0 (empty sequences from CUDAGraph padding) to avoid NaN + safe_e_sum = tl.where(e_sum == 0.0, 1.0, e_sum) + tl.store( + O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + tl.where(e_sum == 0.0, 0.0, acc / safe_e_sum), + mask=mask_d, + ) + + +@enable_compat_on_triton_kernel +@triton.jit +def _compute_num_kv_splits_kernel( + num_kv_splits_ptr, + seq_lens_ptr, + num_seq: tl.constexpr, + max_kv_splits: tl.constexpr, + BLOCK: tl.constexpr, +): + """Compute number of KV splits per sequence based on seq_len.""" + idx = tl.arange(0, BLOCK) + mask = idx < num_seq + seq_len = tl.load(seq_lens_ptr + idx, mask=mask, other=0) + splits = (seq_len + 255) // 256 + splits = tl.minimum(splits, max_kv_splits) + splits = tl.maximum(splits, 1) + tl.store(num_kv_splits_ptr + idx, splits, mask=mask) + + +def compute_num_kv_splits(seq_lens, num_seq, max_kv_splits, out_buf=None): + """ + Compute number of KV splits per sequence. CUDA Graph compatible. + + Args: + seq_lens: [num_seq] int32 tensor of sequence lengths + num_seq: number of sequences + max_kv_splits: maximum number of splits + out_buf: Optional pre-allocated buffer. If provided, writes into it. + + Returns: + num_kv_splits: [num_seq] int32 tensor (or out_buf if provided) + """ + if out_buf is not None: + num_kv_splits = out_buf + else: + num_kv_splits = paddle.empty([num_seq], dtype="int32") + if num_seq == 0: + return num_kv_splits + BLOCK = triton.next_power_of_2(num_seq) + _compute_num_kv_splits_kernel[(1,)]( + num_kv_splits, seq_lens, num_seq=num_seq, max_kv_splits=max_kv_splits, BLOCK=BLOCK + ) + return num_kv_splits + + +def _decode_grouped_att_m_fwd( + q, + k_buffer, + v_buffer, + att_out, + att_lse, + kv_indptr, + kv_indices, + num_kv_splits, + max_kv_splits, + sm_scale, + kv_block_size, +): + """Launch stage 1 grouped decode attention kernel.""" + BLOCK = 32 + Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] + + if Lk == 576: + BLOCK_DMODEL = 512 + BLOCK_DPE = 64 + elif Lk == 288: + BLOCK_DMODEL = 256 + BLOCK_DPE = 32 + else: + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DPE = 0 + BLOCK_DV = triton.next_power_of_2(Lv) + + batch, head_num = q.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k_buffer.shape[1] + + BLOCK_H = 16 + MAX_KV_SPLITS = max_kv_splits + grid = ( + batch, + triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), + MAX_KV_SPLITS, + ) + + _fwd_grouped_kernel_stage1[grid]( + q, + k_buffer, + v_buffer, + sm_scale, + kv_indptr, + kv_indices, + att_out, + att_lse, + num_kv_splits, + q.strides[0], + q.strides[1], + k_buffer.strides[0], + k_buffer.strides[1], + k_buffer.strides[2], + v_buffer.strides[0], + v_buffer.strides[1], + v_buffer.strides[2], + att_out.strides[0], + att_out.strides[1], + att_out.strides[2], + kv_group_num=kv_group_num, + q_head_num=head_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DPE=BLOCK_DPE, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK, + BLOCK_H=BLOCK_H, + MIN_BLOCK_KV=_MIN_BLOCK_KV, + KV_BLOCK_SIZE=kv_block_size, + num_warps=4, + num_stages=2, + Lk=Lk, + Lv=Lv, + ) + + +def _decode_softmax_reducev_fwd( + logits, + lse, + q, + o, + v_buffer, + kv_indptr, + num_kv_splits, + max_kv_splits, +): + """Launch stage 2 reduce kernel.""" + batch, head_num = q.shape[0], q.shape[1] + Lv = v_buffer.shape[-1] + BLOCK_DV = triton.next_power_of_2(Lv) + + MAX_KV_SPLITS = max_kv_splits + + grid = (batch, head_num) + _fwd_kernel_stage2[grid]( + logits, + lse, + o, + kv_indptr, + num_kv_splits, + logits.strides[0], + logits.strides[1], + logits.strides[2], + o.strides[0], + o.strides[1], + MAX_KV_SPLITS=MAX_KV_SPLITS, + MIN_BLOCK_KV=_MIN_BLOCK_KV, + BLOCK_DV=BLOCK_DV, + Lv=Lv, + num_warps=4, + num_stages=2, + ) + + +def decode_attention_fwd( + q, + k_buffer, + v_buffer, + o, + kv_indptr, + kv_indices, + attn_logits, + attn_lse, + num_kv_splits, + max_kv_splits, + sm_scale, + kv_block_size, +): + """ + Triton decode attention for paged KV cache (split-KV approach). + + Args: + q: [batch, num_heads, Lk] query tensor + k_buffer: [num_blocks, kv_heads, block_size, Lk] paged key cache + v_buffer: [num_blocks, kv_heads, block_size, Lv] paged value cache + o: [batch, num_heads, Lv] output tensor + kv_indptr: [batch+1] CSR indptr for KV + kv_indices: [total_kv_len] flat token indices (block_id * block_size + offset) + attn_logits: [batch, num_heads, max_kv_splits, Lv] intermediate buffer + attn_lse: [batch, num_heads, max_kv_splits] intermediate lse buffer + num_kv_splits: [batch] number of splits per sequence + max_kv_splits: int, maximum number of splits + sm_scale: float, softmax scale + kv_block_size: int, the page block size + """ + _decode_grouped_att_m_fwd( + q, + k_buffer, + v_buffer, + attn_logits, + attn_lse, + kv_indptr, + kv_indices, + num_kv_splits, + max_kv_splits, + sm_scale, + kv_block_size, + ) + _decode_softmax_reducev_fwd( + attn_logits, + attn_lse, + q, + o, + v_buffer, + kv_indptr, + num_kv_splits, + max_kv_splits, + ) diff --git a/fastdeploy/model_executor/layers/attention/triton_ops/mla_cache_kernel.py b/fastdeploy/model_executor/layers/attention/triton_ops/mla_cache_kernel.py new file mode 100644 index 00000000000..a856ab12a53 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/triton_ops/mla_cache_kernel.py @@ -0,0 +1,147 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from +# https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/utils.py +# Licensed under Apache License 2.0 +# +# Triton kernel for writing MLA compressed KV cache into paged buffer. +""" + +import paddle +import triton +import triton.language as tl + +from fastdeploy.model_executor.ops.triton_ops.triton_utils import ( + enable_compat_on_triton_kernel, +) + + +@enable_compat_on_triton_kernel +@triton.jit +def _mla_write_cache_kernel( + compressed_kv_ptr, + k_pe_ptr, + cache_ptr, + slot_mapping_ptr, + stride_ckv_token, + stride_kpe_token, + stride_cache_block, + stride_cache_bs, + kv_lora_rank: tl.constexpr, + qk_rope_head_dim: tl.constexpr, + KV_BLOCK_SIZE: tl.constexpr, + BLOCK: tl.constexpr, + NUM_BLOCKS: tl.constexpr, +): + """ + Write [compressed_kv || k_pe] into paged cache at slot_mapping positions. + + cache layout: [num_blocks, 1, block_size, head_dim] + slot_mapping[i] = block_id * block_size + offset + """ + pid_token = tl.program_id(0) + pid_blk = tl.program_id(1) + + base = pid_blk * BLOCK + offs = base + tl.arange(0, BLOCK) + total_dim = kv_lora_rank + qk_rope_head_dim + mask = offs < total_dim + + slot = tl.load(slot_mapping_ptr + pid_token).to(tl.int64) + block_id = slot // KV_BLOCK_SIZE + offset_in_block = slot % KV_BLOCK_SIZE + + # Bounds check: skip if block_id is out of range + if block_id >= NUM_BLOCKS or block_id < 0: + return + + dst_ptr = cache_ptr + block_id * stride_cache_block + offset_in_block * stride_cache_bs + offs + + if base + BLOCK <= kv_lora_rank: + src = tl.load(compressed_kv_ptr + pid_token * stride_ckv_token + offs, mask=mask) + elif base >= kv_lora_rank: + offs_rope = offs - kv_lora_rank + src = tl.load(k_pe_ptr + pid_token * stride_kpe_token + offs_rope, mask=mask) + else: + is_nope = offs < kv_lora_rank + is_rope = (offs >= kv_lora_rank) & (offs < total_dim) + src_nope = tl.load( + compressed_kv_ptr + pid_token * stride_ckv_token + offs, + mask=mask & is_nope, + other=0, + ) + src_rope = tl.load( + k_pe_ptr + pid_token * stride_kpe_token + (offs - kv_lora_rank), + mask=mask & is_rope, + other=0, + ) + src = tl.where(is_nope, src_nope, src_rope) + + tl.store(dst_ptr, src, mask=mask) + + +def mla_write_cache_triton( + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, + latent_cache: paddle.Tensor, + slot_mapping: paddle.Tensor, +): + """ + Write [compressed_kv || k_pe] into paged latent_cache at slot_mapping positions. + + Args: + compressed_kv: [num_tokens, kv_lora_rank] + k_pe: [num_tokens, 1, qk_rope_head_dim] or [num_tokens, qk_rope_head_dim] + latent_cache: [num_blocks, 1, block_size, kv_lora_rank + qk_rope_head_dim] + slot_mapping: [num_tokens] int64 + """ + num_tokens = compressed_kv.shape[0] + if num_tokens == 0: + return + + kv_lora_rank = compressed_kv.shape[-1] + k_pe_flat = k_pe.reshape([num_tokens, -1]) + qk_rope_head_dim = k_pe_flat.shape[-1] + total_dim = kv_lora_rank + qk_rope_head_dim + + kv_block_size = latent_cache.shape[2] + + BLOCK = 128 + grid = (num_tokens, triton.cdiv(total_dim, BLOCK)) + + # stride for cache: [num_blocks, 1, block_size, head_dim] + # stride_cache_block = 1 * block_size * head_dim + # stride_cache_bs = head_dim (stride along block_size dim) + stride_cache_block = latent_cache.strides[0] + stride_cache_bs = latent_cache.strides[2] + + num_blocks = latent_cache.shape[0] + + _mla_write_cache_kernel[grid]( + compressed_kv, + k_pe_flat, + latent_cache, + slot_mapping, + compressed_kv.strides[0], + k_pe_flat.strides[0], + stride_cache_block, + stride_cache_bs, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + KV_BLOCK_SIZE=kv_block_size, + BLOCK=BLOCK, + NUM_BLOCKS=num_blocks, + ) diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index 07052c7faeb..11086963b0f 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -1105,7 +1105,8 @@ def forward( residual: paddle.Tensor, ): """ """ - if hidden_states.shape[0] > 0: + need_do_attention = forward_meta.max_len_tensor_cpu[1] > 0 or forward_meta.max_len_tensor_cpu[2] > 0 + if hidden_states.shape[0] > 0 and need_do_attention: hidden_states, residual = self.input_layernorm( hidden_states, residual_input=residual, forward_meta=forward_meta ) diff --git a/fastdeploy/platforms/base.py b/fastdeploy/platforms/base.py index bb30663492a..ee0554a5bc3 100644 --- a/fastdeploy/platforms/base.py +++ b/fastdeploy/platforms/base.py @@ -30,6 +30,7 @@ class _Backend(enum.Enum): PLAS_ATTN = enum.auto() HPU_ATTN = enum.auto() FLASH_MASK_ATTN = enum.auto() + TRITON_MLA_ATTN = enum.auto() class Platform: diff --git a/fastdeploy/platforms/cuda.py b/fastdeploy/platforms/cuda.py index acdf40d8fdb..cab62ac3f07 100644 --- a/fastdeploy/platforms/cuda.py +++ b/fastdeploy/platforms/cuda.py @@ -73,8 +73,11 @@ def get_attention_backend_cls(cls, selected_backend: _Backend): elif selected_backend == _Backend.FLASH_MASK_ATTN: logger.info("Using FLASH MASK ATTN backend.") return "fastdeploy.model_executor.layers.attention.FlashMaskAttentionBackend" + elif selected_backend == _Backend.TRITON_MLA_ATTN: + logger.info("Using TRITON MLA ATTN backend.") + return "fastdeploy.model_executor.layers.attention.TritonMLAAttentionBackend" else: raise ValueError( "Invalid attention backend you specified.\n" - "Now only support [NATIVE_ATTN, MLA_ATTN, APPEND_ATTN] in cuda place." + "Now only support [NATIVE_ATTN, MLA_ATTN, APPEND_ATTN, TRITON_MLA_ATTN] in cuda place." ) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 7bcd61bdade..8ba89be1b33 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -51,6 +51,9 @@ from fastdeploy.model_executor.layers.attention.mla_attention_backend import ( MLAAttentionBackend, ) +from fastdeploy.model_executor.layers.attention.triton_mla_attention_backend import ( + TritonMLAAttentionBackend, +) from fastdeploy.model_executor.layers.moe.routing_indices_cache import ( RoutingReplayManager, ) @@ -285,7 +288,7 @@ def __init__( # NOTE:(changwenbin) Determine whether it is Multi-Head Latent Attention, # To rationalize the allocation of kvcache. - self.mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN" + self.mla_cache = envs.FD_ATTENTION_BACKEND in ("MLA_ATTN", "TRITON_MLA_ATTN") self.dsa_cache = envs.FD_ATTENTION_BACKEND == "DSA_ATTN" self.enable_cache_manager_v1 = envs.ENABLE_V1_KVCACHE_MANAGER @@ -1382,8 +1385,10 @@ def _compute_position_ids_and_slot_mapping(self) -> None: applicable to all models that need per-token KV cache physical slot addresses. Results are stored in self.forward_meta. """ - # NOTE(zhushengguang): Only support MLAAttentionBackend and DSAAttentionBackend currently. - if not isinstance(self.attn_backends[0], (MLAAttentionBackend, DSAAttentionBackend)): + # Only works on MLAAttentionBackend, DSAAttentionBackend, TritonMLAAttentionBackend + if not isinstance( + self.attn_backends[0], (MLAAttentionBackend, DSAAttentionBackend, TritonMLAAttentionBackend) + ): return current_total_tokens = self.forward_meta.ids_remove_padding.shape[0] position_ids = self.share_inputs["position_ids_buffer"][:current_total_tokens] diff --git a/scripts/.coveragerc b/scripts/.coveragerc index 5f1cb5e6e1d..00d4efa45ea 100644 --- a/scripts/.coveragerc +++ b/scripts/.coveragerc @@ -31,6 +31,8 @@ omit = */fastdeploy/benchmarks/lib/endpoint_request_func.py */fastdeploy/model_executor/graph_optimization/utils.py */fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py + */fastdeploy/model_executor/layers/attention/triton_ops/* + */fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py */fastdeploy/model_executor/ops/gpu/fastdeploy_ops.py */fastdeploy/model_executor/ops/gpu/fastdeploy_ops/__init__.py */fastdeploy/model_executor/ops/gpu/deep_gemm/utils.py diff --git a/tests/deterministic/test_triton_decode_attention.py b/tests/deterministic/test_triton_decode_attention.py new file mode 100644 index 00000000000..06f5b438beb --- /dev/null +++ b/tests/deterministic/test_triton_decode_attention.py @@ -0,0 +1,429 @@ +""" +Triton decode attention kernel tests — correctness, determinism, and edge cases. + +Tests for compute_num_kv_splits and decode_attention_fwd from +fastdeploy/model_executor/layers/attention/triton_ops/decode_attention.py + +Correctness verification strategy: + A naive Python reference implementation (using float32 matmul + softmax) is + compared against the Triton kernel output using max absolute diff and cosine + similarity thresholds. Broad parametrized coverage spans head configurations + (MHA/GQA/MQA), data types (float16/bfloat16), sequence lengths, and edge cases. + +Test scenarios: +1. compute_num_kv_splits: basic, edge cases, max capping +2. decode_attention_fwd correctness: MHA/GQA/MQA, various head_dim, float16/bfloat16 +3. Determinism: multiple runs produce identical results +4. Edge cases: single token seq, large seq, multiple batches + +Usage: + CUDA_VISIBLE_DEVICES=0 python -m pytest tests/deterministic/test_triton_decode_attention.py -v +""" + +import numpy as np +import paddle +import pytest + +from fastdeploy.model_executor.layers.attention.triton_ops.decode_attention import ( + compute_num_kv_splits, + decode_attention_fwd, +) + +# --------------------------------------------------------------------------- +# Skip if no CUDA +# --------------------------------------------------------------------------- +pytestmark = pytest.mark.skipif( + not paddle.is_compiled_with_cuda() or paddle.device.cuda.device_count() == 0, + reason="Triton decode attention requires CUDA", +) + +# --------------------------------------------------------------------------- +# Tolerance constants +# --------------------------------------------------------------------------- +FP16_ATOL = 2e-2 +BF16_ATOL = 5e-2 +COSINE_SIM_THRESHOLD = 1 - 1e-3 + + +# --------------------------------------------------------------------------- +# Metric helpers +# --------------------------------------------------------------------------- +def cosine_similarity(a, b): + """Compute cosine similarity between two tensors (flattened).""" + a_flat = a.astype("float32").reshape([-1]) + b_flat = b.astype("float32").reshape([-1]) + dot = float(paddle.sum(a_flat * b_flat).item()) + norm_a = float(paddle.sqrt(paddle.sum(a_flat * a_flat)).item()) + norm_b = float(paddle.sqrt(paddle.sum(b_flat * b_flat)).item()) + if norm_a == 0 or norm_b == 0: + return 0.0 + return dot / (norm_a * norm_b) + + +# --------------------------------------------------------------------------- +# Reference implementation: naive decode attention (no paging) +# --------------------------------------------------------------------------- +def naive_decode_attention_ref(q, k_pages, v_pages, kv_indptr, kv_indices, sm_scale, kv_block_size): + """ + Naive Python reference for decode attention with paged KV cache. + + Args: + q: [batch, num_heads, Lk] + k_pages: [num_blocks, kv_heads, block_size, Lk] + v_pages: [num_blocks, kv_heads, block_size, Lv] + kv_indptr: [batch+1] CSR pointers + kv_indices: [total_kv_len] flat indices (block_id * block_size + offset) + sm_scale: float + kv_block_size: int + + Returns: + o: [batch, num_heads, Lv] + """ + batch = q.shape[0] + num_heads = q.shape[1] + Lk = q.shape[2] + kv_heads = k_pages.shape[1] + Lv = v_pages.shape[-1] + group_size = num_heads // kv_heads + + q_np = q.astype("float32").numpy() + k_np = k_pages.astype("float32").numpy() + v_np = v_pages.astype("float32").numpy() + indptr_np = kv_indptr.numpy() + indices_np = kv_indices.numpy() + + o_np = np.zeros([batch, num_heads, Lv], dtype=np.float32) + + for b in range(batch): + start = indptr_np[b] + end = indptr_np[b + 1] + seq_len = end - start + if seq_len == 0: + continue + + # Gather K and V for this batch from paged cache + k_gathered = np.zeros([kv_heads, seq_len, Lk], dtype=np.float32) + v_gathered = np.zeros([kv_heads, seq_len, Lv], dtype=np.float32) + + for t in range(seq_len): + flat_idx = indices_np[start + t] + block_id = flat_idx // kv_block_size + offset = flat_idx % kv_block_size + k_gathered[:, t, :] = k_np[block_id, :, offset, :] + v_gathered[:, t, :] = v_np[block_id, :, offset, :] + + # Expand KV for GQA + for h in range(num_heads): + kv_h = h // group_size + # q_h: [Lk], k_h: [seq_len, Lk], v_h: [seq_len, Lv] + q_h = q_np[b, h, :] + k_h = k_gathered[kv_h] # [seq_len, Lk] + v_h = v_gathered[kv_h] # [seq_len, Lv] + + # Attention scores + scores = (q_h @ k_h.T) * sm_scale # [seq_len] + scores -= np.max(scores) + attn_weights = np.exp(scores) + attn_weights /= np.sum(attn_weights) + 1e-12 + + o_np[b, h, :] = attn_weights @ v_h + + return paddle.to_tensor(o_np) + + +# --------------------------------------------------------------------------- +# Helper: build paged KV cache test data +# --------------------------------------------------------------------------- +def build_decode_test_data( + batch_size, + num_heads, + kv_heads, + head_dim_k, + head_dim_v, + seq_lens, + block_size=16, + dtype="float16", + seed=42, +): + """ + Build test data for decode attention. + + Returns dict with all tensors needed for decode_attention_fwd and the reference. + """ + np.random.seed(seed) + paddle.seed(seed) + + num_blocks_needed = sum((s + block_size - 1) // block_size for s in seq_lens) + num_blocks = max(num_blocks_needed + 4, 8) + + # Paged K/V cache + k_pages = paddle.randn([num_blocks, kv_heads, block_size, head_dim_k]).astype(dtype) + v_pages = paddle.randn([num_blocks, kv_heads, block_size, head_dim_v]).astype(dtype) + + # Allocate blocks sequentially for simplicity + block_cursor = 0 + kv_indptr_list = [0] + kv_indices_list = [] + + for b in range(batch_size): + sl = seq_lens[b] + for t in range(sl): + blk_idx_in_seq = t // block_size + offset = t % block_size + actual_block_id = block_cursor + blk_idx_in_seq + kv_indices_list.append(actual_block_id * block_size + offset) + block_cursor += (sl + block_size - 1) // block_size + kv_indptr_list.append(kv_indptr_list[-1] + sl) + + kv_indptr = paddle.to_tensor(kv_indptr_list, dtype="int32") + kv_indices = paddle.to_tensor(kv_indices_list, dtype="int32") + + # Query + q = paddle.randn([batch_size, num_heads, head_dim_k]).astype(dtype) + + # Compute num_kv_splits + seq_lens_tensor = paddle.to_tensor(seq_lens, dtype="int32") + max_kv_splits = 32 + num_kv_splits = compute_num_kv_splits(seq_lens_tensor, batch_size, max_kv_splits) + + # Pre-allocate intermediate buffers + Lv = head_dim_v + attn_logits = paddle.empty([batch_size, num_heads, max_kv_splits, Lv], dtype="float32") + attn_lse = paddle.empty([batch_size, num_heads, max_kv_splits], dtype="float32") + o = paddle.empty([batch_size, num_heads, Lv], dtype=dtype) + + sm_scale = head_dim_k**-0.5 + + return { + "q": q, + "k_pages": k_pages, + "v_pages": v_pages, + "o": o, + "kv_indptr": kv_indptr, + "kv_indices": kv_indices, + "attn_logits": attn_logits, + "attn_lse": attn_lse, + "num_kv_splits": num_kv_splits, + "max_kv_splits": max_kv_splits, + "sm_scale": sm_scale, + "block_size": block_size, + "seq_lens": seq_lens, + } + + +# =========================================================================== +# Tests for compute_num_kv_splits +# =========================================================================== +class TestComputeNumKvSplits: + """Tests for the compute_num_kv_splits utility.""" + + def test_basic(self): + """Short sequences should get 1 split.""" + seq_lens = paddle.to_tensor([100, 200, 50], dtype="int32") + splits = compute_num_kv_splits(seq_lens, 3, max_kv_splits=32) + splits_np = splits[:3].numpy() + # (100+255)//256 = 1, (200+255)//256 = 1, (50+255)//256 = 1 + np.testing.assert_array_equal(splits_np, [1, 1, 1]) + + def test_long_sequences(self): + """Longer sequences should get more splits.""" + seq_lens = paddle.to_tensor([512, 1024, 2048], dtype="int32") + splits = compute_num_kv_splits(seq_lens, 3, max_kv_splits=32) + splits_np = splits[:3].numpy() + expected = [min((s + 255) // 256, 32) for s in [512, 1024, 2048]] + np.testing.assert_array_equal(splits_np, expected) + + def test_max_capping(self): + """Splits should be capped at max_kv_splits.""" + seq_lens = paddle.to_tensor([100000], dtype="int32") + splits = compute_num_kv_splits(seq_lens, 1, max_kv_splits=16) + assert splits[0].item() == 16 + + def test_single_token(self): + """Single-token sequence should get 1 split.""" + seq_lens = paddle.to_tensor([1], dtype="int32") + splits = compute_num_kv_splits(seq_lens, 1, max_kv_splits=32) + assert splits[0].item() == 1 + + def test_out_buf(self): + """Pre-allocated output buffer should be respected.""" + seq_lens = paddle.to_tensor([512, 1024], dtype="int32") + out_buf = paddle.zeros([2], dtype="int32") + result = compute_num_kv_splits(seq_lens, 2, max_kv_splits=32, out_buf=out_buf) + # Result should be the same object (same data pointer) + assert result.data_ptr() == out_buf.data_ptr() + assert result[0].item() == 2 # (512+255)//256 = 2 + assert result[1].item() == 4 # (1024+255)//256 = 4 + + def test_empty(self): + """num_seq=0 should return without error.""" + seq_lens = paddle.empty([0], dtype="int32") + result = compute_num_kv_splits(seq_lens, 0, max_kv_splits=32) + assert result.shape[0] == 0 + + +# =========================================================================== +# Tests for decode_attention_fwd +# =========================================================================== + +# MLA typical configs: Lk = kv_lora_rank + qk_rope_head_dim (e.g. 512+64=576) +_DECODE_CASES = [ + # (name, batch, num_heads, kv_heads, Lk, Lv, seq_lens, block_size) + ("mla_basic_bs1", 1, 16, 1, 576, 512, [64], 16), + ("mla_basic_bs4", 4, 16, 1, 576, 512, [32, 64, 128, 48], 16), + ("mla_long_seq", 1, 16, 1, 576, 512, [1024], 16), + ("mla_short_seq", 2, 16, 1, 576, 512, [1, 3], 16), + ("mla_bs8_mixed", 8, 8, 1, 576, 512, [16, 32, 64, 128, 256, 48, 96, 512], 16), + ("gqa_basic", 2, 16, 4, 128, 128, [64, 128], 16), + ("mha_basic", 2, 8, 8, 128, 128, [32, 64], 16), + ("mla_288", 2, 16, 1, 288, 256, [64, 128], 16), + ("mla_block32", 2, 16, 1, 576, 512, [64, 128], 32), +] + + +@pytest.mark.parametrize( + "name,batch,num_heads,kv_heads,Lk,Lv,seq_lens,block_size", + _DECODE_CASES, + ids=[c[0] for c in _DECODE_CASES], +) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) +def test_decode_attention_correctness(name, batch, num_heads, kv_heads, Lk, Lv, seq_lens, block_size, dtype): + """Triton decode attention output should match naive reference.""" + data = build_decode_test_data( + batch_size=batch, + num_heads=num_heads, + kv_heads=kv_heads, + head_dim_k=Lk, + head_dim_v=Lv, + seq_lens=seq_lens, + block_size=block_size, + dtype=dtype, + ) + + # Run triton kernel + decode_attention_fwd( + data["q"], + data["k_pages"], + data["v_pages"], + data["o"], + data["kv_indptr"], + data["kv_indices"], + data["attn_logits"], + data["attn_lse"], + data["num_kv_splits"], + data["max_kv_splits"], + data["sm_scale"], + data["block_size"], + ) + triton_out = data["o"].astype("float32") + + # Run reference + ref_out = naive_decode_attention_ref( + data["q"], + data["k_pages"], + data["v_pages"], + data["kv_indptr"], + data["kv_indices"], + data["sm_scale"], + data["block_size"], + ) + + max_diff = float(paddle.max(paddle.abs(triton_out - ref_out)).item()) + cos_sim = cosine_similarity(triton_out, ref_out) + + atol = BF16_ATOL if dtype == "bfloat16" else FP16_ATOL + assert max_diff < atol, f"[{name}/{dtype}] max_diff={max_diff:.6f} exceeds atol={atol}" + assert ( + cos_sim > COSINE_SIM_THRESHOLD + ), f"[{name}/{dtype}] cos_sim={cos_sim:.6f} below threshold={COSINE_SIM_THRESHOLD}" + + +# =========================================================================== +# Determinism test +# =========================================================================== +def test_decode_attention_determinism(): + """Multiple runs should produce bitwise identical results.""" + data = build_decode_test_data( + batch_size=4, + num_heads=16, + kv_heads=1, + head_dim_k=576, + head_dim_v=512, + seq_lens=[64, 128, 32, 256], + block_size=16, + dtype="float16", + ) + + results = [] + for _ in range(5): + o = paddle.empty_like(data["o"]) + decode_attention_fwd( + data["q"], + data["k_pages"], + data["v_pages"], + o, + data["kv_indptr"], + data["kv_indices"], + data["attn_logits"], + data["attn_lse"], + data["num_kv_splits"], + data["max_kv_splits"], + data["sm_scale"], + data["block_size"], + ) + results.append(o.astype("float32").numpy()) + + for i in range(1, len(results)): + np.testing.assert_array_equal(results[0], results[i], err_msg=f"Run 0 vs run {i} differ — non-deterministic!") + + +# =========================================================================== +# Edge case: all sequences same length +# =========================================================================== +def test_decode_attention_uniform_seqlens(): + """Uniform sequence lengths should produce correct results.""" + batch = 4 + seq_len = 128 + data = build_decode_test_data( + batch_size=batch, + num_heads=16, + kv_heads=1, + head_dim_k=576, + head_dim_v=512, + seq_lens=[seq_len] * batch, + block_size=16, + dtype="float16", + ) + + decode_attention_fwd( + data["q"], + data["k_pages"], + data["v_pages"], + data["o"], + data["kv_indptr"], + data["kv_indices"], + data["attn_logits"], + data["attn_lse"], + data["num_kv_splits"], + data["max_kv_splits"], + data["sm_scale"], + data["block_size"], + ) + triton_out = data["o"].astype("float32") + + ref_out = naive_decode_attention_ref( + data["q"], + data["k_pages"], + data["v_pages"], + data["kv_indptr"], + data["kv_indices"], + data["sm_scale"], + data["block_size"], + ) + + max_diff = float(paddle.max(paddle.abs(triton_out - ref_out)).item()) + assert max_diff < FP16_ATOL, f"max_diff={max_diff:.6f}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/deterministic/test_triton_mla_cache_kernel.py b/tests/deterministic/test_triton_mla_cache_kernel.py new file mode 100644 index 00000000000..3396d4fec47 --- /dev/null +++ b/tests/deterministic/test_triton_mla_cache_kernel.py @@ -0,0 +1,508 @@ +""" +Triton MLA cache write kernel tests — correctness, determinism, and edge cases. + +Tests for mla_write_cache_triton from +fastdeploy/model_executor/layers/attention/triton_ops/mla_cache_kernel.py + +Correctness verification strategy: + A naive Python reference implementation writes [compressed_kv || k_pe] into + the paged cache at slot_mapping positions. The Triton kernel output is compared + against the reference using exact bitwise equality (pure copy, no arithmetic). + +Test scenarios: +1. Basic write: single batch, sequential slots +2. Multi-batch: multiple batches with different slot positions +3. Non-contiguous slots: random slot assignments across blocks +4. k_pe with extra dim: [num_tokens, 1, qk_rope_head_dim] shape +5. Determinism: multiple runs produce identical cache +6. Edge cases: single token, large batch, unaligned dims + +Usage: + CUDA_VISIBLE_DEVICES=0 python -m pytest tests/deterministic/test_triton_mla_cache_kernel.py -v +""" + +import numpy as np +import paddle +import pytest + +from fastdeploy.model_executor.layers.attention.triton_ops.mla_cache_kernel import ( + mla_write_cache_triton, +) + +# --------------------------------------------------------------------------- +# Skip if no CUDA +# --------------------------------------------------------------------------- +pytestmark = pytest.mark.skipif( + not paddle.is_compiled_with_cuda() or paddle.device.cuda.device_count() == 0, + reason="Triton MLA cache kernel requires CUDA", +) + +# --------------------------------------------------------------------------- +# Typical MLA dimensions +# --------------------------------------------------------------------------- +KV_LORA_RANK = 512 +QK_ROPE_HEAD_DIM = 64 +BLOCK_SIZE = 16 + + +# --------------------------------------------------------------------------- +# Reference implementation +# --------------------------------------------------------------------------- +def mla_write_cache_ref(compressed_kv, k_pe, latent_cache, slot_mapping): + """ + Naive Python reference for MLA cache write. + + Writes [compressed_kv || k_pe] into paged latent_cache at slot_mapping positions. + + Args: + compressed_kv: [num_tokens, kv_lora_rank] + k_pe: [num_tokens, 1, qk_rope_head_dim] or [num_tokens, qk_rope_head_dim] + latent_cache: [num_blocks, 1, block_size, kv_lora_rank + qk_rope_head_dim] + slot_mapping: [num_tokens] int64 + """ + num_tokens = compressed_kv.shape[0] + kv_lora_rank = compressed_kv.shape[-1] + + ckv_np = compressed_kv.astype("float32").numpy() + kpe_np = k_pe.reshape([num_tokens, -1]).astype("float32").numpy() + cache_np = latent_cache.astype("float32").numpy() + slots_np = slot_mapping.numpy() + + kv_block_size = latent_cache.shape[2] + + for i in range(num_tokens): + slot = int(slots_np[i]) + block_id = slot // kv_block_size + offset = slot % kv_block_size + # Write [compressed_kv || k_pe] into cache + cache_np[block_id, 0, offset, :kv_lora_rank] = ckv_np[i] + cache_np[block_id, 0, offset, kv_lora_rank:] = kpe_np[i] + + return paddle.to_tensor(cache_np) + + +# --------------------------------------------------------------------------- +# Helper: build test data +# --------------------------------------------------------------------------- +def build_cache_test_data( + num_tokens, + kv_lora_rank=KV_LORA_RANK, + qk_rope_head_dim=QK_ROPE_HEAD_DIM, + block_size=BLOCK_SIZE, + num_blocks=None, + dtype="bfloat16", + kpe_3d=False, + seed=42, + slot_mapping=None, +): + """Build test data for mla_write_cache_triton.""" + np.random.seed(seed) + paddle.seed(seed) + + if num_blocks is None: + num_blocks = max((num_tokens + block_size - 1) // block_size + 4, 8) + + latent_dim = kv_lora_rank + qk_rope_head_dim + compressed_kv = paddle.randn([num_tokens, kv_lora_rank]).astype(dtype) + + if kpe_3d: + k_pe = paddle.randn([num_tokens, 1, qk_rope_head_dim]).astype(dtype) + else: + k_pe = paddle.randn([num_tokens, qk_rope_head_dim]).astype(dtype) + + latent_cache = paddle.zeros([num_blocks, 1, block_size, latent_dim]).astype(dtype) + + if slot_mapping is None: + slot_mapping = paddle.arange(num_tokens, dtype="int64") + else: + slot_mapping = paddle.to_tensor(slot_mapping, dtype="int64") + + return { + "compressed_kv": compressed_kv, + "k_pe": k_pe, + "latent_cache": latent_cache, + "slot_mapping": slot_mapping, + "kv_lora_rank": kv_lora_rank, + "qk_rope_head_dim": qk_rope_head_dim, + "block_size": block_size, + } + + +# =========================================================================== +# Basic correctness tests +# =========================================================================== +class TestMLAWriteCacheBasic: + """Basic correctness tests for mla_write_cache_triton.""" + + def test_sequential_slots(self): + """Sequential slot mapping: tokens 0..N-1 map to slots 0..N-1.""" + data = build_cache_test_data(num_tokens=16) + + cache_ref = mla_write_cache_ref( + data["compressed_kv"], + data["k_pe"], + data["latent_cache"].clone(), + data["slot_mapping"], + ) + + mla_write_cache_triton( + data["compressed_kv"], + data["k_pe"], + data["latent_cache"], + data["slot_mapping"], + ) + + np.testing.assert_allclose( + data["latent_cache"].astype("float32").numpy(), + cache_ref.numpy(), + atol=1e-6, + err_msg="Sequential slots: triton vs ref mismatch", + ) + + def test_single_token(self): + """Single token write.""" + data = build_cache_test_data(num_tokens=1) + + cache_ref = mla_write_cache_ref( + data["compressed_kv"], + data["k_pe"], + data["latent_cache"].clone(), + data["slot_mapping"], + ) + + mla_write_cache_triton( + data["compressed_kv"], + data["k_pe"], + data["latent_cache"], + data["slot_mapping"], + ) + + np.testing.assert_allclose( + data["latent_cache"].astype("float32").numpy(), + cache_ref.numpy(), + atol=1e-6, + ) + + def test_large_batch(self): + """Large batch of tokens.""" + data = build_cache_test_data(num_tokens=512, num_blocks=64) + + cache_ref = mla_write_cache_ref( + data["compressed_kv"], + data["k_pe"], + data["latent_cache"].clone(), + data["slot_mapping"], + ) + + mla_write_cache_triton( + data["compressed_kv"], + data["k_pe"], + data["latent_cache"], + data["slot_mapping"], + ) + + np.testing.assert_allclose( + data["latent_cache"].astype("float32").numpy(), + cache_ref.numpy(), + atol=1e-6, + ) + + +# =========================================================================== +# Non-contiguous slot tests +# =========================================================================== +class TestMLAWriteCacheNonContiguous: + """Tests with non-sequential slot mappings.""" + + def test_scattered_slots(self): + """Slots scattered across multiple blocks.""" + num_tokens = 8 + block_size = 4 + num_blocks = 16 + # Scatter tokens across different blocks + slots = [0, 5, 10, 15, 20, 25, 30, 35] + + data = build_cache_test_data( + num_tokens=num_tokens, + block_size=block_size, + num_blocks=num_blocks, + slot_mapping=slots, + ) + + cache_ref = mla_write_cache_ref( + data["compressed_kv"], + data["k_pe"], + data["latent_cache"].clone(), + data["slot_mapping"], + ) + + mla_write_cache_triton( + data["compressed_kv"], + data["k_pe"], + data["latent_cache"], + data["slot_mapping"], + ) + + np.testing.assert_allclose( + data["latent_cache"].astype("float32").numpy(), + cache_ref.numpy(), + atol=1e-6, + err_msg="Scattered slots: triton vs ref mismatch", + ) + + def test_random_slots(self): + """Random slot assignments.""" + num_tokens = 32 + block_size = 8 + num_blocks = 32 + np.random.seed(123) + # Generate unique random slots within valid range + max_slot = num_blocks * block_size + slots = np.random.choice(max_slot, size=num_tokens, replace=False).tolist() + + data = build_cache_test_data( + num_tokens=num_tokens, + block_size=block_size, + num_blocks=num_blocks, + slot_mapping=slots, + ) + + cache_ref = mla_write_cache_ref( + data["compressed_kv"], + data["k_pe"], + data["latent_cache"].clone(), + data["slot_mapping"], + ) + + mla_write_cache_triton( + data["compressed_kv"], + data["k_pe"], + data["latent_cache"], + data["slot_mapping"], + ) + + np.testing.assert_allclose( + data["latent_cache"].astype("float32").numpy(), + cache_ref.numpy(), + atol=1e-6, + err_msg="Random slots: triton vs ref mismatch", + ) + + +# =========================================================================== +# k_pe shape variants +# =========================================================================== +class TestMLAWriteCacheKpeShape: + """Tests for different k_pe tensor shapes.""" + + def test_kpe_2d(self): + """k_pe shape: [num_tokens, qk_rope_head_dim].""" + data = build_cache_test_data(num_tokens=16, kpe_3d=False) + + cache_ref = mla_write_cache_ref( + data["compressed_kv"], + data["k_pe"], + data["latent_cache"].clone(), + data["slot_mapping"], + ) + + mla_write_cache_triton( + data["compressed_kv"], + data["k_pe"], + data["latent_cache"], + data["slot_mapping"], + ) + + np.testing.assert_allclose( + data["latent_cache"].astype("float32").numpy(), + cache_ref.numpy(), + atol=1e-6, + ) + + def test_kpe_3d(self): + """k_pe shape: [num_tokens, 1, qk_rope_head_dim].""" + data = build_cache_test_data(num_tokens=16, kpe_3d=True) + + cache_ref = mla_write_cache_ref( + data["compressed_kv"], + data["k_pe"], + data["latent_cache"].clone(), + data["slot_mapping"], + ) + + mla_write_cache_triton( + data["compressed_kv"], + data["k_pe"], + data["latent_cache"], + data["slot_mapping"], + ) + + np.testing.assert_allclose( + data["latent_cache"].astype("float32").numpy(), + cache_ref.numpy(), + atol=1e-6, + ) + + +# =========================================================================== +# Parametrized dtype/dimension tests +# =========================================================================== +_DIMENSION_CASES = [ + # (name, kv_lora_rank, qk_rope_head_dim) — typical MLA configs + ("deepseek_v3", 512, 64), + ("deepseek_v2_lite", 256, 32), + ("small_dims", 128, 32), +] + + +@pytest.mark.parametrize( + "name,kv_lora_rank,qk_rope_head_dim", + _DIMENSION_CASES, + ids=[c[0] for c in _DIMENSION_CASES], +) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) +def test_write_cache_dimensions(name, kv_lora_rank, qk_rope_head_dim, dtype): + """Test cache write across different MLA dimension configurations and dtypes.""" + data = build_cache_test_data( + num_tokens=32, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + dtype=dtype, + ) + + cache_ref = mla_write_cache_ref( + data["compressed_kv"], + data["k_pe"], + data["latent_cache"].clone(), + data["slot_mapping"], + ) + + mla_write_cache_triton( + data["compressed_kv"], + data["k_pe"], + data["latent_cache"], + data["slot_mapping"], + ) + + np.testing.assert_allclose( + data["latent_cache"].astype("float32").numpy(), + cache_ref.numpy(), + atol=1e-6, + err_msg=f"[{name}/{dtype}] triton vs ref mismatch", + ) + + +# =========================================================================== +# Determinism test +# =========================================================================== +def test_write_cache_determinism(): + """Multiple runs should produce bitwise identical cache contents.""" + data = build_cache_test_data(num_tokens=64) + + results = [] + for _ in range(5): + cache = data["latent_cache"].clone() + mla_write_cache_triton( + data["compressed_kv"], + data["k_pe"], + cache, + data["slot_mapping"], + ) + results.append(cache.astype("float32").numpy()) + + for i in range(1, len(results)): + np.testing.assert_array_equal(results[0], results[i], err_msg=f"Run 0 vs run {i} differ — non-deterministic!") + + +# =========================================================================== +# Manual baseline test (hand-crafted small tensors) +# =========================================================================== +def test_manual_baseline(): + """Hand-crafted small tensors to verify exact values end up in the right cache slots.""" + kv_lora_rank = 4 + qk_rope_head_dim = 2 + block_size = 2 + num_blocks = 4 + latent_dim = kv_lora_rank + qk_rope_head_dim # 6 + + # 3 tokens, deterministic values + compressed_kv = paddle.to_tensor( + [ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0], + ], + dtype="float32", + ) + + k_pe = paddle.to_tensor( + [ + [0.1, 0.2], + [0.3, 0.4], + [0.5, 0.6], + ], + dtype="float32", + ) + + latent_cache = paddle.zeros([num_blocks, 1, block_size, latent_dim], dtype="float32") + + # slot_mapping: token 0 -> slot 0 (block 0, offset 0) + # token 1 -> slot 3 (block 1, offset 1) + # token 2 -> slot 5 (block 2, offset 1) + slot_mapping = paddle.to_tensor([0, 3, 5], dtype="int64") + + mla_write_cache_triton(compressed_kv, k_pe, latent_cache, slot_mapping) + + cache_np = latent_cache.numpy() + + # Token 0 -> block 0, offset 0: [1, 2, 3, 4, 0.1, 0.2] + expected_0 = np.array([1.0, 2.0, 3.0, 4.0, 0.1, 0.2], dtype=np.float32) + np.testing.assert_allclose(cache_np[0, 0, 0, :], expected_0, atol=1e-6) + + # Token 1 -> block 1, offset 1: [5, 6, 7, 8, 0.3, 0.4] + expected_1 = np.array([5.0, 6.0, 7.0, 8.0, 0.3, 0.4], dtype=np.float32) + np.testing.assert_allclose(cache_np[1, 0, 1, :], expected_1, atol=1e-6) + + # Token 2 -> block 2, offset 1: [9, 10, 11, 12, 0.5, 0.6] + expected_2 = np.array([9.0, 10.0, 11.0, 12.0, 0.5, 0.6], dtype=np.float32) + np.testing.assert_allclose(cache_np[2, 0, 1, :], expected_2, atol=1e-6) + + # All other slots should remain zero + zero_slots = [ + (0, 0, 1), # block 0, offset 1 + (1, 0, 0), # block 1, offset 0 + (2, 0, 0), # block 2, offset 0 + (3, 0, 0), # block 3, offset 0 + (3, 0, 1), # block 3, offset 1 + ] + for blk, head, off in zero_slots: + np.testing.assert_array_equal( + cache_np[blk, head, off, :], + np.zeros(latent_dim, dtype=np.float32), + err_msg=f"Slot ({blk},{head},{off}) should be zero", + ) + + +# =========================================================================== +# Empty token test +# =========================================================================== +def test_empty_tokens(): + """Zero tokens should be a no-op (no crash).""" + latent_dim = KV_LORA_RANK + QK_ROPE_HEAD_DIM + compressed_kv = paddle.empty([0, KV_LORA_RANK], dtype="bfloat16") + k_pe = paddle.empty([0, QK_ROPE_HEAD_DIM], dtype="bfloat16") + latent_cache = paddle.zeros([4, 1, BLOCK_SIZE, latent_dim], dtype="bfloat16") + slot_mapping = paddle.empty([0], dtype="int64") + + # Should not crash + mla_write_cache_triton(compressed_kv, k_pe, latent_cache, slot_mapping) + + # Cache should remain all zeros + np.testing.assert_array_equal( + latent_cache.astype("float32").numpy(), + np.zeros_like(latent_cache.astype("float32").numpy()), + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])