diff --git a/fastdeploy/cache_manager/gdn_state_pool.py b/fastdeploy/cache_manager/gdn_state_pool.py new file mode 100644 index 00000000000..2c8716e4cf2 --- /dev/null +++ b/fastdeploy/cache_manager/gdn_state_pool.py @@ -0,0 +1,229 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +GDN (Gated Delta Network) State Pool — pre-allocated GPU tensor pool +for conv and SSM states used by Qwen3.5 linear attention layers. + +Design: + - Analogous to paged KV cache block pool, but each slot stores a + complete per-request state (conv state + SSM state). + - All GDN layers share a single pool object, indexed by layer_idx. + - Slot 0 is reserved as a zero-filled padding sentinel. + PAD_SLOT_ID (-1) is mapped to slot 0 via +1 offset when building + gdn_slot_ids in ForwardMeta, so reads return zero and writes are + harmless. + +Pool layouts: + conv_pool: [num_gdn_layers, pool_size, conv_dim, conv_kernel_size - 1] + ssm_pool: [num_gdn_layers, pool_size, num_v_heads, head_k_dim, head_v_dim] + + where pool_size = max_num_seqs + 1 (slot 0 = padding sentinel). +""" + +import logging +from typing import List + +import paddle + +logger = logging.getLogger(__name__) + +PAD_SLOT_ID = -1 + + +class GDNSlotAllocator: + """Lightweight CPU-only slot allocator for GDN state pool. + + Used by ResourceManagerV1 on the scheduler side to manage slot IDs + without requiring paddle/GPU access. 
The corresponding GPU tensors + live in GDNStatePool on the worker side. + + Slot 0 is reserved as a padding sentinel. Valid slots: 1..max_num_seqs. + """ + + def __init__(self, max_num_seqs: int): + self.max_num_seqs = max_num_seqs + self._free_slots: List[int] = list(range(max_num_seqs, 0, -1)) + + def allocate(self) -> int: + """Allocate a single slot ID. + + Returns: + Allocated slot ID (1-based). + + Raises: + RuntimeError: If no free slots available. + """ + if not self._free_slots: + raise RuntimeError(f"GDNSlotAllocator: no free slots (max_num_seqs={self.max_num_seqs})") + return self._free_slots.pop() + + def free(self, slot_id: int): + """Return a slot ID to the free list. + + Args: + slot_id: Slot ID to free (1-based). Slot 0 is silently ignored. + """ + if slot_id > 0: + self._free_slots.append(slot_id) + + @property + def num_free_slots(self) -> int: + return len(self._free_slots) + + +class GDNStatePool: + """Pre-allocated GPU tensor pool for GDN conv and SSM states. + + Args: + max_num_seqs: Maximum number of concurrent sequences. + num_gdn_layers: Number of GDN (linear_attention) layers in the model. + conv_dim: TP-local convolution dimension (key_dim * 2 + value_dim) // tp_size. + conv_kernel_size: Causal conv1d kernel width (e.g. 4). + num_v_heads: TP-local number of value heads (num_v_heads // tp_size). + head_k_dim: Per-head key dimension. + head_v_dim: Per-head value dimension. + conv_dtype: Data type for conv state pool (default: bfloat16). 
+ """ + + def __init__( + self, + max_num_seqs: int, + num_gdn_layers: int, + conv_dim: int, + conv_kernel_size: int, + num_v_heads: int, + head_k_dim: int, + head_v_dim: int, + conv_dtype: paddle.dtype = paddle.bfloat16, + ): + self.max_num_seqs = max_num_seqs + self.num_gdn_layers = num_gdn_layers + self.conv_dim = conv_dim + self.conv_kernel_size = conv_kernel_size + self.num_v_heads = num_v_heads + self.head_k_dim = head_k_dim + self.head_v_dim = head_v_dim + + # pool_size = max_num_seqs + 1; slot 0 is the padding sentinel + pool_size = max_num_seqs + 1 + + # Conv state pool: [num_gdn_layers, pool_size, conv_dim, conv_kernel_size - 1] + conv_state_len = conv_kernel_size - 1 + self.conv_pool = paddle.zeros( + [num_gdn_layers, pool_size, conv_dim, conv_state_len], + dtype=conv_dtype, + ) + + # SSM state pool: [num_gdn_layers, pool_size, num_v_heads, head_k_dim, head_v_dim] + # K-first layout matching FLA kernel native format. + # float32 for numerical stability (SSM state accumulates over many steps). + self.ssm_pool = paddle.zeros( + [num_gdn_layers, pool_size, num_v_heads, head_k_dim, head_v_dim], + dtype=paddle.float32, + ) + + conv_mem_mb = (num_gdn_layers * pool_size * conv_dim * conv_state_len * paddle.finfo(conv_dtype).bits // 8) / ( + 1024 * 1024 + ) + ssm_mem_mb = (num_gdn_layers * pool_size * num_v_heads * head_k_dim * head_v_dim * 4) / (1024 * 1024) + logger.info( + f"GDNStatePool allocated: " + f"conv_pool {list(self.conv_pool.shape)} ({conv_mem_mb:.1f} MB), " + f"ssm_pool {list(self.ssm_pool.shape)} ({ssm_mem_mb:.1f} MB)" + ) + + # Free slot list: valid slots are 1..max_num_seqs (slot 0 is sentinel) + self._free_slots: List[int] = list(range(max_num_seqs, 0, -1)) + + logger.info( + f"GDNStatePool: {len(self._free_slots)} free slots available " f"(slot 0 reserved as padding sentinel)" + ) + + def allocate(self, n: int = 1) -> List[int]: + """Allocate n slot IDs from the free list. + + Args: + n: Number of slots to allocate. 
+ + Returns: + List of allocated slot IDs (1-based, already offset for pool indexing). + + Raises: + RuntimeError: If not enough free slots available. + """ + if len(self._free_slots) < n: + raise RuntimeError(f"GDNStatePool: cannot allocate {n} slots, " f"only {len(self._free_slots)} free") + allocated = [self._free_slots.pop() for _ in range(n)] + return allocated + + def free(self, slot_ids: List[int]): + """Return slot IDs to the free list and zero-out their state. + + Args: + slot_ids: List of slot IDs to free (1-based pool indices). + Slot 0 (padding sentinel) is silently ignored. + """ + valid = [s for s in slot_ids if s > 0] + if not valid: + return + self.reset_slots(valid) + self._free_slots.extend(valid) + + @property + def num_free_slots(self) -> int: + """Number of currently available slots.""" + return len(self._free_slots) + + def get_layer_conv_pool(self, layer_idx: int) -> paddle.Tensor: + """Get conv state pool for a specific GDN layer. + + Returns: + Tensor of shape [pool_size, conv_dim, conv_kernel_size - 1] + """ + return self.conv_pool[layer_idx] + + def get_layer_ssm_pool(self, layer_idx: int) -> paddle.Tensor: + """Get SSM state pool for a specific GDN layer. + + Returns: + Tensor of shape [pool_size, num_v_heads, head_k_dim, head_v_dim] + """ + return self.ssm_pool[layer_idx] + + def reset_slots(self, slot_ids: List[int]): + """Zero-out conv and SSM state for given slots across all layers. + + Used when requests finish and their slots are returned to the free list. + + Args: + slot_ids: List of slot indices to reset (already +1 offset applied). + """ + if not slot_ids: + return + idx = paddle.to_tensor(slot_ids, dtype=paddle.int64) + for layer_idx in range(self.num_gdn_layers): + self.conv_pool[layer_idx][idx] = 0 + self.ssm_pool[layer_idx][idx] = 0 + + @staticmethod + def offset_slot_ids(raw_slot_ids: paddle.Tensor) -> paddle.Tensor: + """Apply +1 offset to raw slot IDs so PAD_SLOT_ID (-1) maps to slot 0. 
+ + Args: + raw_slot_ids: [batch_size] int32, may contain PAD_SLOT_ID (-1). + + Returns: + Offset slot IDs where -1 -> 0, 0 -> 1, 1 -> 2, etc. + """ + return raw_slot_ids + 1 diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 391e2038534..ecf8a3f780c 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -188,6 +188,9 @@ def __init__( self.task_type = RequestType.PREFILL self.has_been_preempted_before = False self.idx = None + # GDN (Gated Delta Network) SSM state slot ID in GDNStatePool. + # Allocated by ResourceManagerV1 during schedule(), freed in _free_blocks(). + self.gdn_slot_id = None self.need_prefill_tokens = self.prompt_token_ids_len self.audio_output_token_ids = [] # extend block tables diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index b0425d779d1..d382d41d0d2 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -28,6 +28,7 @@ import paddle from fastdeploy import envs +from fastdeploy.cache_manager.gdn_state_pool import GDNSlotAllocator from fastdeploy.cache_manager.multimodal_cache_manager import ( EncoderCacheManager, ProcessorCacheManager, @@ -231,6 +232,32 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l # Scheduler-side requests that have not been moved into resource manager waiting queue yet. 
self.scheduler_unhandled_request_num = 0 + # GDN SSM slot allocator (None if model has no GDN layers) + self.gdn_slot_allocator = self._init_gdn_slot_allocator() + + def _init_gdn_slot_allocator(self): + """Create GDN slot allocator if the model has GDN (linear_attention) layers.""" + model_config = self.config.model_config + layer_types = getattr(model_config, "layer_types", None) + if layer_types is None: + # Generate from full_attention_interval if not explicit + interval = getattr(model_config, "full_attention_interval", None) + if interval is None: + return None + num_layers = model_config.num_hidden_layers + layer_types = [ + "linear_attention" if (i + 1) % interval != 0 else "full_attention" for i in range(num_layers) + ] + num_gdn_layers = sum(1 for lt in layer_types if lt == "linear_attention") + if num_gdn_layers == 0: + return None + allocator = GDNSlotAllocator(self.config.scheduler_config.max_num_seqs) + llm_logger.info( + f"GDN slot allocator initialized: {num_gdn_layers} GDN layers, " + f"max_num_seqs={self.config.scheduler_config.max_num_seqs}" + ) + return allocator + def allocated_slots(self, request: Request): return len(request.block_tables) * self.config.cache_config.block_size @@ -997,6 +1024,9 @@ def _allocate_decode_and_extend(): request.block_tables.extend(extra_gpu_block_ids) self.waiting.popleft() self.running.append(request) + # Allocate GDN SSM state slot if model has GDN layers + if self.gdn_slot_allocator is not None and request.gdn_slot_id is None: + request.gdn_slot_id = self.gdn_slot_allocator.allocate() scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) token_budget -= num_new_tokens request.num_computed_tokens += num_new_tokens @@ -1062,6 +1092,9 @@ def _allocate_decode_and_extend(): request.block_tables.extend(extra_gpu_block_ids) self.waiting.popleft() self.running.append(request) + # Re-allocate GDN slot for preempted request + if self.gdn_slot_allocator is not None and request.gdn_slot_id is None: + 
request.gdn_slot_id = self.gdn_slot_allocator.allocate() scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) token_budget -= num_new_tokens request.num_computed_tokens += num_new_tokens @@ -1452,6 +1485,11 @@ def _free_blocks(self, request: Request): self.cache_manager.recycle_gpu_blocks(request.block_tables, request.request_id) request.block_tables = [] + # Free GDN SSM state slot + if self.gdn_slot_allocator is not None and request.gdn_slot_id is not None: + self.gdn_slot_allocator.free(request.gdn_slot_id) + request.gdn_slot_id = None + if request.request_id in self.using_extend_tables_req_id: reuse_block_num = self.reuse_block_num_map[request.request_id] diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index 9e512f32355..27028b65213 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -17,7 +17,7 @@ import logging from dataclasses import dataclass, fields from enum import IntEnum, auto -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional import paddle @@ -160,6 +160,21 @@ class ForwardMeta: position_ids: Optional[paddle.Tensor] = None + # ============================================================ + # GDN (Gated Delta Network) linear attention fields + # ============================================================ + # GDN state pool object (shared across all GDN layers) + gdn_state_pool: Optional[Any] = None + # Slot indices into the GDN state pool [batch_size], int32. + # PAD_SLOT_ID=-1 requests are offset to slot 0 (zero-filled sentinel). 
+ gdn_slot_ids: Optional[paddle.Tensor] = None + # Whether each request has prior state (False = new request) [batch_size], bool + gdn_has_initial_state: Optional[paddle.Tensor] = None + # CPU sequence lengths for causal_conv1d_fn varlen [batch_size], int32 + gdn_seq_lens_cpu: Optional[List[int]] = None + # GDN attention backend (GDNAttentionBackend instance, inherits AttentionBackend) + gdn_attn_backend: Optional[Any] = None + def clear_caches(self): """Safely clean up the caches""" if self.caches: diff --git a/fastdeploy/model_executor/layers/attention/__init__.py b/fastdeploy/model_executor/layers/attention/__init__.py index 7efc3259fbc..c75f7e4de18 100644 --- a/fastdeploy/model_executor/layers/attention/__init__.py +++ b/fastdeploy/model_executor/layers/attention/__init__.py @@ -20,6 +20,7 @@ from .dsa_attention_backend import DSAAttentionBackend from .flash_attn_backend import FlashAttentionBackend from .flash_mask_attn_backend import FlashMaskAttentionBackend +from .gdn_attention import GDNAttention from .mla_attention_backend import MLAAttentionBackend from .moba_attention_backend import PlasAttentionBackend from .native_paddle_backend import PaddleNativeAttnBackend @@ -34,6 +35,7 @@ "FlashAttentionBackend", "BlockAttentionBackend", "Attention", + "GDNAttention", "PlasAttentionBackend", "FlashMaskAttentionBackend", ] diff --git a/fastdeploy/model_executor/layers/attention/gdn_attention.py b/fastdeploy/model_executor/layers/attention/gdn_attention.py new file mode 100644 index 00000000000..fa590f9e160 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/gdn_attention.py @@ -0,0 +1,83 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GDN linear attention trampoline layer. + +Analogous to ``Attention`` for standard softmax attention — instantiated in the +model layer's ``__init__`` and called in ``forward``, delegating to the backend +on ``forward_meta``. + +Usage:: + + class Qwen3_5GatedDeltaNet(nn.Layer): + def __init__(self, fd_config, layer_id, ...): + ... + self.gdn_attn = GDNAttention(fd_config, layer_id) + + def forward(self, forward_meta, hidden_states): + ... + core_attn_out = self.gdn_attn(mixed_qkv, a, b, self, forward_meta) +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import paddle +from paddle import nn + +from fastdeploy.config import FDConfig + +if TYPE_CHECKING: + from fastdeploy.model_executor.forward_meta import ForwardMeta + + +class GDNAttention(nn.Layer): + """GDN (Gated Delta Network) linear attention trampoline. + + Mirrors the role of :class:`Attention` for softmax attention: + the model layer holds ``self.gdn_attn = GDNAttention(fd_config, layer_id)`` + and calls ``self.gdn_attn(mixed_qkv, a, b, self, forward_meta)`` in its + forward. + + Internally delegates to ``forward_meta.gdn_attn_backend.forward()``. + """ + + def __init__(self, fd_config: FDConfig, layer_id: int) -> None: + super().__init__() + self.fd_config = fd_config + self.layer_id = layer_id + + def forward( + self, + mixed_qkv: paddle.Tensor, + a: paddle.Tensor, + b: paddle.Tensor, + layer: nn.Layer, + forward_meta: ForwardMeta, + ) -> paddle.Tensor: + """Forward pass — delegates to the GDN attention backend. 
+ + Args: + mixed_qkv: Projected QKV tensor ``[num_tokens, conv_dim]``. + a: Gating input a ``[num_tokens, num_v_heads]``. + b: Gating input b ``[num_tokens, num_v_heads]``. + layer: The parent ``Qwen3_5GatedDeltaNet`` instance (provides + ``conv_weight``, ``A_log_local``, ``dt_bias_local``, etc.). + forward_meta: Per-step forward metadata. + + Returns: + Attention output ``[num_tokens, num_v_heads_local, head_v_dim]``. + """ + return forward_meta.gdn_attn_backend.forward(mixed_qkv, a, b, layer, forward_meta) diff --git a/fastdeploy/model_executor/layers/attention/gdn_backend.py b/fastdeploy/model_executor/layers/attention/gdn_backend.py new file mode 100644 index 00000000000..2edf1653c71 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/gdn_backend.py @@ -0,0 +1,337 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +GDN (Gated Delta Network) Attention Backend for Qwen3.5 linear attention. 
+ +Architecture (inspired by SGLang's gdn_backend.py): + - GDNKernelDispatcher: strategy pattern routing to Triton (FLA) kernels + - GDNAttentionBackend(AttentionBackend): unified forward entry for model layer + +Call chain: + Model Layer: Qwen3_5GatedDeltaNet.forward() + |-- projections (qkv, z, b, a) + |-- self.gdn_attn(mixed_qkv, a, b, self, forward_meta) [GDNAttention trampoline] + |-- forward_meta.gdn_attn_backend.forward(mixed_qkv, a, b, layer, forward_meta) + | + GDNAttentionBackend.forward() + |-- causal_conv1d (decode: triton update / prefill: triton varlen) + |-- split Q,K,V + fused_gdn_gating + |-- GVA repeat (if needed) + |-- kernel_dispatcher.decode/extend(...) + | + GDNKernelDispatcher + |-- Triton FLA kernels (decode: fused_recurrent, extend: chunk) + |-- gated RMSNorm + output projection +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import paddle + +from fastdeploy.cache_manager.gdn_state_pool import GDNStatePool +from fastdeploy.model_executor.layers.attention.base_attention_backend import ( + AttentionBackend, +) +from fastdeploy.model_executor.ops.triton_ops.causal_conv1d import ( + causal_conv1d_fn, + causal_conv1d_update, +) +from fastdeploy.model_executor.ops.triton_ops.fla import ( + chunk_gated_delta_rule, + fused_gdn_gating, + fused_recurrent_gated_delta_rule_update, +) + +if TYPE_CHECKING: + from fastdeploy.model_executor.forward_meta import ForwardMeta + +logger = logging.getLogger(__name__) + +# ============================================================================== +# Kernel Dispatcher (strategy pattern, inspired by SGLang GDNKernelDispatcher) +# ============================================================================== + + +class GDNKernelDispatcher: + """Strategy pattern — routes SSM kernel calls to Triton (FLA) kernels. + + Currently: Triton (FLA) kernels only. + Future: add FlashInfer/CUDA by selecting different kernels at construction time. 
+ """ + + def decode( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + *, + ssm_pool: paddle.Tensor, + slot_ids: paddle.Tensor, + ) -> paddle.Tensor: + """Decode: fused recurrent kernel (pool-indexed, in-place). + + Args: + q,k: [batch, 1, H, DK] + v: [batch, 1, H, DV] + g,beta: [batch, 1, H] + ssm_pool: [pool_size, H, K, V] + slot_ids: [batch] int32 (already offset: PAD→0) + Returns: + [batch, 1, H, DV] + """ + return fused_recurrent_gated_delta_rule_update( + q=q, + k=k, + v=v, + g=g, + beta=beta, + ssm_pool=ssm_pool, + ssm_indices=slot_ids, + use_qk_l2norm_in_kernel=True, + ) + + def extend( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + *, + ssm_pool: paddle.Tensor, + slot_ids: paddle.Tensor, + cu_seqlens: paddle.Tensor, + ) -> paddle.Tensor: + """Prefill/extend: chunk kernel + state writeback. + + Args: + q,k: [1, total_tokens, H, DK] + v: [1, total_tokens, H, DV] + g,beta: [1, total_tokens, H] + ssm_pool: [pool_size, H, K, V] + slot_ids: [batch] int32 (already offset: PAD→0) + cu_seqlens: [batch+1] int32 + Returns: + [1, total_tokens, H, DV] + """ + # Clone initial state from pool — chunk kernel updates in-place + initial_state = ssm_pool[slot_ids].clone() # [batch, H, K, V] + + o, _h = chunk_gated_delta_rule( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=None, + initial_state=initial_state, + initial_state_indices=paddle.arange(slot_ids.shape[0], dtype=paddle.int32), + cu_seqlens=cu_seqlens, + use_qk_l2norm_in_kernel=True, + ) + + # Write back final states (updated in-place by the kernel) + batch_size = slot_ids.shape[0] + for i in range(batch_size): + sid = int(slot_ids[i]) + if sid > 0: # skip padding sentinel (slot 0) + ssm_pool[sid] = initial_state[i] + + return o + + +# ============================================================================== +# GDN Attention Backend +# 
============================================================================== + + +class GDNAttentionBackend(AttentionBackend): + """GDN (Gated Delta Network) linear attention backend. + + Inherits AttentionBackend for formal consistency with FD/vLLM/SGLang. + The model layer calls through GDNAttention trampoline: + self.gdn_attn(mixed_qkv, a, b, self, forward_meta) + → forward_meta.gdn_attn_backend.forward(...) + + Internal flow: + 1. Causal Conv1d (decode: triton update / prefill: triton varlen) + 2. Split Q,K,V + fused_gdn_gating + 3. GVA repeat (if num_v_heads > num_k_heads) + 4. SSM kernel via GDNKernelDispatcher (decode / extend) + """ + + def __init__(self): + self.kernel_dispatcher = GDNKernelDispatcher() + + def init_attention_metadata(self, forward_meta: ForwardMeta): + """GDN does not need standard attention metadata.""" + pass + + def forward( + self, + mixed_qkv: paddle.Tensor, + a: paddle.Tensor, + b: paddle.Tensor, + layer, + forward_meta: ForwardMeta, + ) -> paddle.Tensor: + """Unified forward entry — model layer calls this single method. + + Args: + mixed_qkv: [num_tokens, conv_dim] — projected QKV (after in_proj_qkv) + a: [num_tokens, num_v_heads] — alpha gating input + b: [num_tokens, num_v_heads] — beta gating input + layer: Qwen3_5GatedDeltaNet instance (provides conv_weight, A_log, dt_bias, dims) + forward_meta: ForwardMeta (contains pool, slot_ids, forward_mode, etc.) 
+ + Returns: + [num_tokens, num_v_heads_local, head_v_dim] + """ + num_tokens = mixed_qkv.shape[0] + is_decode = forward_meta.forward_mode.is_decode() + + # Get pool views for this layer + gdn_pool = forward_meta.gdn_state_pool + raw_slot_ids = forward_meta.gdn_slot_ids # [active_bs] — already truncated + + conv_pool = gdn_pool.get_layer_conv_pool(layer.gdn_layer_idx) + ssm_pool = gdn_pool.get_layer_ssm_pool(layer.gdn_layer_idx) + + # Offset slot_ids: PAD_SLOT_ID (-1) → slot 0 + slot_ids = GDNStatePool.offset_slot_ids(raw_slot_ids) + batch_size = slot_ids.shape[0] + + # ============================================================ + # 1. Causal Conv1d + # ============================================================ + conv_weight_local = layer.conv1d_weight[: layer.conv_dim] + + if is_decode: + mixed_qkv = causal_conv1d_update( + x=mixed_qkv, + conv_state=conv_pool, + weight=conv_weight_local, + bias=None, + activation="silu", + conv_state_indices=slot_ids, + ) + else: + # Slice cu_seqlens_q to active batch (it's a shared full buffer) + cu_seqlens = forward_meta.cu_seqlens_q[: batch_size + 1] + mixed_qkv = causal_conv1d_fn( + x=mixed_qkv.T, # [dim, total_tokens] + weight=conv_weight_local, + bias=None, + conv_states=conv_pool, + query_start_loc=cu_seqlens, + seq_lens_cpu=forward_meta.gdn_seq_lens_cpu, + cache_indices=slot_ids, + has_initial_state=forward_meta.gdn_has_initial_state, + activation="silu", + ).T # [total_tokens, dim] + + # ============================================================ + # 2. 
Split Q, K, V + # ============================================================ + key_dim_local = layer.num_k_heads_local * layer.head_k_dim + value_dim_local = layer.num_v_heads_local * layer.head_v_dim + + q, k, v = paddle.split( + mixed_qkv, + [ + key_dim_local, + key_dim_local, + value_dim_local, + ], + axis=-1, + ) + + q = q.reshape([num_tokens, layer.num_k_heads_local, layer.head_k_dim]) + k = k.reshape([num_tokens, layer.num_k_heads_local, layer.head_k_dim]) + v = v.reshape([num_tokens, layer.num_v_heads_local, layer.head_v_dim]) + + # ============================================================ + # 3. GDN Gating + # ============================================================ + A_log_local = layer.A_log[: layer.num_v_heads_local] + dt_bias_local = layer.dt_bias[: layer.num_v_heads_local] + a_local = a[:, : layer.num_v_heads_local] + b_local = b[:, : layer.num_v_heads_local] + + g, beta = fused_gdn_gating(A_log_local, a_local, b_local, dt_bias_local) + + # ============================================================ + # 4. GVA repeat (if num_v_heads > num_k_heads) + # ============================================================ + if layer.num_v_heads_per_k_head > 1: + q = ( + q.unsqueeze(2) + .expand([num_tokens, layer.num_k_heads_local, layer.num_v_heads_per_k_head, layer.head_k_dim]) + .reshape([num_tokens, layer.num_v_heads_local, layer.head_k_dim]) + ) + k = ( + k.unsqueeze(2) + .expand([num_tokens, layer.num_k_heads_local, layer.num_v_heads_per_k_head, layer.head_k_dim]) + .reshape([num_tokens, layer.num_v_heads_local, layer.head_k_dim]) + ) + + # ============================================================ + # 5. 
Core SSM Attention (via dispatcher) + # ============================================================ + if is_decode: + # Decode: fused recurrent (Triton) + q_4d = q.unsqueeze(1) # [batch, 1, H, K] + k_4d = k.unsqueeze(1) + v_4d = v.unsqueeze(1) + g_4d = g.unsqueeze(1) # [batch, 1, H] + beta_4d = beta.unsqueeze(1) + + o = self.kernel_dispatcher.decode( + q_4d, + k_4d, + v_4d, + g_4d, + beta_4d, + ssm_pool=ssm_pool, + slot_ids=slot_ids, + ) + core_attn_out = o.squeeze(1) # [batch, H, V] + + else: + # Prefill/extend: chunk (Triton) + cu_seqlens = forward_meta.cu_seqlens_q[: batch_size + 1] + q_4d = q.unsqueeze(0) # [1, total_tokens, H, K] + k_4d = k.unsqueeze(0) + v_4d = v.unsqueeze(0) + g_4d = g.unsqueeze(0) # [1, total_tokens, H] + beta_4d = beta.unsqueeze(0) + + o = self.kernel_dispatcher.extend( + q_4d, + k_4d, + v_4d, + g_4d, + beta_4d, + ssm_pool=ssm_pool, + slot_ids=slot_ids, + cu_seqlens=cu_seqlens, + ) + core_attn_out = o.squeeze(0) # [total_tokens, H, V] + + return core_attn_out diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index bed61bd5b1d..a3cf906e83c 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -841,6 +841,9 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( request.block_tables, dtype="int32" ) + # Write GDN slot ID for this request + if hasattr(request, "gdn_slot_id") and request.gdn_slot_id is not None: + self.share_inputs["gdn_slot_ids"][idx : idx + 1] = request.gdn_slot_id self.share_inputs["stop_flags"][idx : idx + 1] = False self.share_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length @@ -905,6 +908,8 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = logger.info(f"Handle abort request {request} at idx {idx}") 
self.share_inputs["preempted_idx"][idx : idx + 1, :] = 1 self.share_inputs["block_tables"][idx : idx + 1, :] = -1 + # Reset GDN slot ID to PAD + self.share_inputs["gdn_slot_ids"][idx : idx + 1] = -1 self.share_inputs["stop_flags"][idx : idx + 1] = True self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = 0 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 @@ -1349,6 +1354,38 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): self.forward_meta.is_zero_size = self.forward_meta.ids_remove_padding.shape[0] == 0 self.forward_meta.exist_prefill = self.exist_prefill() + # Populate GDN state pool fields if available + if getattr(self, "gdn_state_pool", None) is not None: + self.forward_meta.gdn_state_pool = self.gdn_state_pool + + # Compute active batch size: requests with seq_lens_this_time > 0. + # share_inputs tensors are pre-allocated to [max_num_seqs], but GDN + # kernels (causal_conv1d_fn, chunk_gated_delta_rule) need tensors + # sized to the actual active batch — unlike AppendAttention which + # has its own metadata to handle the full buffer. 
+        gdn_active_bs = int((self.share_inputs["seq_lens_this_time"] > 0).sum())
+
+        # gdn_slot_ids: from share_inputs, truncated to active batch
+        raw_slot_ids = self.share_inputs.get("gdn_slot_ids")
+        if raw_slot_ids is not None:
+            self.forward_meta.gdn_slot_ids = raw_slot_ids[:gdn_active_bs]
+
+        # gdn_has_initial_state: True if request has prior state (seq_lens_decoder > 0)
+        # For prefill of a new request, seq_lens_decoder=0 → has_initial_state=False
+        # For decode or chunked prefill continuation, seq_lens_decoder>0 → has_initial_state=True
+        if self.forward_meta.seq_lens_decoder is not None:
+            self.forward_meta.gdn_has_initial_state = self.forward_meta.seq_lens_decoder[:gdn_active_bs] > 0
+
+        # Derive gdn_seq_lens_cpu from seq_lens_this_time, truncated to active batch
+        if self.forward_meta.seq_lens_this_time is not None:
+            self.forward_meta.gdn_seq_lens_cpu = (
+                self.forward_meta.seq_lens_this_time[:gdn_active_bs].numpy().tolist()
+            )
+
+        # GDN attention backend
+        if getattr(self, "gdn_attn_backend", None) is not None:
+            self.forward_meta.gdn_attn_backend = self.gdn_attn_backend
+
     def initialize_kv_cache(self, profile: bool = False) -> None:
         """
         Initialize kv cache
@@ -1499,6 +1536,72 @@ def initialize_kv_cache(self, profile: bool = False) -> None:
             paddle.device.cuda.empty_cache()
         logger.info("kv cache is initialized!")
 
+        # Initialize GDN state pool if the model has GDN layers
+        self._initialize_gdn_state_pool()
+
+    def _initialize_gdn_state_pool(self) -> None:
+        """Initialize GDN (Gated Delta Network) state pool if the model uses GDN layers."""
+        config = self.model_config
+
+        # Detect GDN layers by checking for layer_types config
+        layer_types = getattr(config, "layer_types", None)
+        if layer_types is None:
+            # Generate from full_attention_interval if not explicit
+            interval = getattr(config, "full_attention_interval", None)
+            if interval is None:
+                self.gdn_state_pool = None
+                return
+            num_layers = config.num_hidden_layers
+            layer_types = [
+                "linear_attention" if (i + 1) % interval != 0 else "full_attention" for i in range(num_layers)
+            ]
+
+        num_gdn_layers = sum(1 for lt in layer_types if lt == "linear_attention")
+        if num_gdn_layers == 0:
+            self.gdn_state_pool = None
+            return
+
+        # GDN-specific dimensions
+        head_k_dim = getattr(config, "linear_key_head_dim", 128)
+        head_v_dim = getattr(config, "linear_value_head_dim", 128)
+        num_k_heads = getattr(config, "linear_num_key_heads", 16)
+        num_v_heads = getattr(config, "linear_num_value_heads", 16)
+        conv_kernel_size = getattr(config, "linear_conv_kernel_dim", 4)
+
+        tp_size = self.parallel_config.tensor_parallel_size
+
+        key_dim = head_k_dim * num_k_heads
+        value_dim = head_v_dim * num_v_heads
+        conv_dim_local = (key_dim * 2 + value_dim) // tp_size
+        num_v_heads_local = num_v_heads // tp_size
+
+        max_num_seqs = self.scheduler_config.max_num_seqs
+
+        from fastdeploy.cache_manager.gdn_state_pool import GDNStatePool
+
+        self.gdn_state_pool = GDNStatePool(
+            max_num_seqs=max_num_seqs,
+            num_gdn_layers=num_gdn_layers,
+            conv_dim=conv_dim_local,
+            conv_kernel_size=conv_kernel_size,
+            num_v_heads=num_v_heads_local,
+            head_k_dim=head_k_dim,
+            head_v_dim=head_v_dim,
+        )
+        logger.info(
+            f"GDN state pool initialized: {num_gdn_layers} GDN layers, "
+            f"max_num_seqs={max_num_seqs}, conv_dim_local={conv_dim_local}, "
+            f"num_v_heads_local={num_v_heads_local}"
+        )
+
+        # Initialize GDN attention backend (kernel dispatcher + unified forward)
+        from fastdeploy.model_executor.layers.attention.gdn_backend import (
+            GDNAttentionBackend,
+        )
+
+        self.gdn_attn_backend = GDNAttentionBackend()
+        logger.info("GDN attention backend initialized (GDNKernelDispatcher + GDNAttentionBackend)")
+
     def _initialize_attn_backend(self) -> None:
         """
         Initialize attention backends
diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py
index 56858aae516..09d48fc9dff 100644
--- a/fastdeploy/worker/input_batch.py
+++ b/fastdeploy/worker/input_batch.py
@@ -232,6 +232,9 @@ def init_share_inputs(self):
         ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
         self.block_tables = paddle.full([max_num_seqs, pre_max_block_num], -1, dtype="int32")
 
+        # GDN SSM state pool slot IDs (PAD_SLOT_ID=-1 for empty slots)
+        self.gdn_slot_ids = paddle.full([max_num_seqs], -1, dtype="int32")
+
         # Initialize free list
         free_list = list(
             range(
diff --git a/tests/cache_manager/test_gdn_state_pool.py b/tests/cache_manager/test_gdn_state_pool.py
new file mode 100644
index 00000000000..61477227792
--- /dev/null
+++ b/tests/cache_manager/test_gdn_state_pool.py
@@ -0,0 +1,332 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Unit tests for GDNStatePool.
+
+Tests cover:
+    1. Pool allocation shapes and dtypes
+    2. Slot 0 padding sentinel invariant
+    3. Per-layer pool view indexing
+    4. reset_slots zeroing across all layers
+    5. offset_slot_ids PAD_SLOT_ID mapping
+    6. Read/write round-trip per slot
+    7. Stored attributes, multi-layer independence, pool allocate/free
+"""
+
+import unittest
+
+import paddle
+
+from fastdeploy.cache_manager.gdn_state_pool import (
+    PAD_SLOT_ID,
+    GDNSlotAllocator,
+    GDNStatePool,
+)
+
+
+class TestGDNStatePool(unittest.TestCase):
+    """Tests for GDNStatePool construction and basic operations."""
+
+    # Typical Qwen3.5 config (small scale for testing)
+    MAX_NUM_SEQS = 8
+    NUM_GDN_LAYERS = 4
+    CONV_DIM = 64  # e.g. (key_dim*2 + value_dim) // tp_size
+    CONV_KERNEL_SIZE = 4
+    NUM_V_HEADS = 4  # TP-local
+    HEAD_K_DIM = 16
+    HEAD_V_DIM = 16
+
+    def setUp(self):
+        self.pool = GDNStatePool(
+            max_num_seqs=self.MAX_NUM_SEQS,
+            num_gdn_layers=self.NUM_GDN_LAYERS,
+            conv_dim=self.CONV_DIM,
+            conv_kernel_size=self.CONV_KERNEL_SIZE,
+            num_v_heads=self.NUM_V_HEADS,
+            head_k_dim=self.HEAD_K_DIM,
+            head_v_dim=self.HEAD_V_DIM,
+        )
+
+    # ----------------------------------------------------------------
+    # 1. Shape and dtype checks
+    # ----------------------------------------------------------------
+    def test_conv_pool_shape(self):
+        """Conv pool shape: [num_gdn_layers, pool_size, conv_dim, conv_kernel_size-1]"""
+        pool_size = self.MAX_NUM_SEQS + 1
+        expected = [self.NUM_GDN_LAYERS, pool_size, self.CONV_DIM, self.CONV_KERNEL_SIZE - 1]
+        self.assertEqual(list(self.pool.conv_pool.shape), expected)
+
+    def test_ssm_pool_shape(self):
+        """SSM pool shape: [num_gdn_layers, pool_size, num_v_heads, head_k_dim, head_v_dim]"""
+        pool_size = self.MAX_NUM_SEQS + 1
+        expected = [self.NUM_GDN_LAYERS, pool_size, self.NUM_V_HEADS, self.HEAD_K_DIM, self.HEAD_V_DIM]
+        self.assertEqual(list(self.pool.ssm_pool.shape), expected)
+
+    def test_conv_pool_dtype(self):
+        """Conv pool should be bfloat16 by default."""
+        self.assertEqual(self.pool.conv_pool.dtype, paddle.bfloat16)
+
+    def test_ssm_pool_dtype(self):
+        """SSM pool should be float32 for numerical stability."""
+        self.assertEqual(self.pool.ssm_pool.dtype, paddle.float32)
+
+    def test_custom_conv_dtype(self):
+        """Conv pool should respect custom dtype."""
+        pool = GDNStatePool(
+            max_num_seqs=4,
+            num_gdn_layers=1,
+            conv_dim=16,
+            conv_kernel_size=4,
+            num_v_heads=2,
+            head_k_dim=8,
+            head_v_dim=8,
+            conv_dtype=paddle.float32,
+        )
+        self.assertEqual(pool.conv_pool.dtype, paddle.float32)
+
+    # ----------------------------------------------------------------
+    # 2. Slot 0 padding sentinel
+    # ----------------------------------------------------------------
+    def test_slot_zero_is_zero_after_init(self):
+        """Slot 0 (padding sentinel) must be all-zeros after construction."""
+        for layer_idx in range(self.NUM_GDN_LAYERS):
+            conv_slot0 = self.pool.get_layer_conv_pool(layer_idx)[0]
+            ssm_slot0 = self.pool.get_layer_ssm_pool(layer_idx)[0]
+            self.assertTrue(paddle.all(conv_slot0 == 0).item())
+            self.assertTrue(paddle.all(ssm_slot0 == 0).item())
+
+    def test_slot_zero_stays_zero_after_write(self):
+        """Writing to slot 0 should work (it's just a safety net for PAD reads)."""
+        # Simulate a padded write going to slot 0
+        self.pool.conv_pool[0, 0] = 999.0
+        # After reset_slots, slot 0 should be zero again
+        self.pool.reset_slots([0])
+        conv_slot0 = self.pool.get_layer_conv_pool(0)[0]
+        self.assertTrue(paddle.all(conv_slot0 == 0).item())
+
+    # ----------------------------------------------------------------
+    # 3. Per-layer pool view
+    # ----------------------------------------------------------------
+    def test_get_layer_conv_pool_shape(self):
+        """get_layer_conv_pool returns [pool_size, conv_dim, conv_kernel_size-1]."""
+        pool_size = self.MAX_NUM_SEQS + 1
+        view = self.pool.get_layer_conv_pool(0)
+        expected = [pool_size, self.CONV_DIM, self.CONV_KERNEL_SIZE - 1]
+        self.assertEqual(list(view.shape), expected)
+
+    def test_get_layer_ssm_pool_shape(self):
+        """get_layer_ssm_pool returns [pool_size, num_v_heads, head_k_dim, head_v_dim]."""
+        pool_size = self.MAX_NUM_SEQS + 1
+        view = self.pool.get_layer_ssm_pool(0)
+        expected = [pool_size, self.NUM_V_HEADS, self.HEAD_K_DIM, self.HEAD_V_DIM]
+        self.assertEqual(list(view.shape), expected)
+
+    def test_layer_views_are_independent(self):
+        """Writing to layer 0's pool should not affect layer 1's pool."""
+        self.pool.get_layer_ssm_pool(0)[1] = 42.0
+        ssm_layer1_slot1 = self.pool.get_layer_ssm_pool(1)[1]
+        self.assertTrue(paddle.all(ssm_layer1_slot1 == 0).item())
+
+    # ----------------------------------------------------------------
+    # 4. reset_slots
+    # ----------------------------------------------------------------
+    def test_reset_slots_zeros_conv_and_ssm(self):
+        """reset_slots should zero out conv and SSM state for given slots across all layers."""
+        # Write non-zero data to slots 1, 2, 3
+        for layer_idx in range(self.NUM_GDN_LAYERS):
+            for slot in [1, 2, 3]:
+                self.pool.conv_pool[layer_idx, slot] = float(slot)
+                self.pool.ssm_pool[layer_idx, slot] = float(slot)
+
+        # Reset slots 1 and 3
+        self.pool.reset_slots([1, 3])
+
+        for layer_idx in range(self.NUM_GDN_LAYERS):
+            # Slot 1 and 3 should be zero
+            self.assertTrue(paddle.all(self.pool.conv_pool[layer_idx, 1] == 0).item())
+            self.assertTrue(paddle.all(self.pool.ssm_pool[layer_idx, 1] == 0).item())
+            self.assertTrue(paddle.all(self.pool.conv_pool[layer_idx, 3] == 0).item())
+            self.assertTrue(paddle.all(self.pool.ssm_pool[layer_idx, 3] == 0).item())
+            # Slot 2 should still have data
+            self.assertTrue(paddle.all(self.pool.conv_pool[layer_idx, 2] == 2.0).item())
+            self.assertTrue(paddle.all(self.pool.ssm_pool[layer_idx, 2] == 2.0).item())
+
+    def test_reset_slots_empty_list(self):
+        """reset_slots with empty list should be a no-op."""
+        self.pool.ssm_pool[0, 1] = 99.0
+        self.pool.reset_slots([])
+        self.assertTrue(paddle.all(self.pool.ssm_pool[0, 1] == 99.0).item())
+
+    # ----------------------------------------------------------------
+    # 5. offset_slot_ids
+    # ----------------------------------------------------------------
+    def test_offset_slot_ids_basic(self):
+        """offset_slot_ids: -1->0, 0->1, 1->2, etc."""
+        raw = paddle.to_tensor([-1, 0, 1, 5], dtype=paddle.int32)
+        offset = GDNStatePool.offset_slot_ids(raw)
+        expected = paddle.to_tensor([0, 1, 2, 6], dtype=paddle.int32)
+        self.assertTrue(paddle.all(offset == expected).item())
+
+    def test_offset_slot_ids_pad_maps_to_sentinel(self):
+        """PAD_SLOT_ID (-1) should map to slot 0 (the zero-filled sentinel)."""
+        raw = paddle.to_tensor([PAD_SLOT_ID], dtype=paddle.int32)
+        offset = GDNStatePool.offset_slot_ids(raw)
+        self.assertEqual(offset[0].item(), 0)
+
+    # ----------------------------------------------------------------
+    # 6. Read/write round-trip
+    # ----------------------------------------------------------------
+    def test_ssm_read_write_roundtrip(self):
+        """Write a known state to a slot, read it back, verify equality."""
+        layer_idx = 2
+        slot_id = 5  # after +1 offset
+        state = paddle.randn([self.NUM_V_HEADS, self.HEAD_K_DIM, self.HEAD_V_DIM], dtype=paddle.float32)
+
+        # Write
+        self.pool.get_layer_ssm_pool(layer_idx)[slot_id] = state
+
+        # Read back
+        read_back = self.pool.get_layer_ssm_pool(layer_idx)[slot_id]
+        self.assertTrue(paddle.allclose(read_back, state).item())
+
+    def test_conv_read_write_roundtrip(self):
+        """Write a known conv state to a slot, read it back, verify equality."""
+        layer_idx = 1
+        slot_id = 3
+        state = paddle.randn([self.CONV_DIM, self.CONV_KERNEL_SIZE - 1], dtype=paddle.bfloat16)
+
+        self.pool.get_layer_conv_pool(layer_idx)[slot_id] = state
+        read_back = self.pool.get_layer_conv_pool(layer_idx)[slot_id]
+        self.assertTrue(paddle.allclose(read_back.cast(paddle.float32), state.cast(paddle.float32), atol=1e-2).item())
+
+    # ----------------------------------------------------------------
+    # 7. Stored attributes
+    # ----------------------------------------------------------------
+    def test_stored_attributes(self):
+        """Pool should store construction parameters for later inspection."""
+        self.assertEqual(self.pool.max_num_seqs, self.MAX_NUM_SEQS)
+        self.assertEqual(self.pool.num_gdn_layers, self.NUM_GDN_LAYERS)
+        self.assertEqual(self.pool.conv_dim, self.CONV_DIM)
+        self.assertEqual(self.pool.conv_kernel_size, self.CONV_KERNEL_SIZE)
+        self.assertEqual(self.pool.num_v_heads, self.NUM_V_HEADS)
+        self.assertEqual(self.pool.head_k_dim, self.HEAD_K_DIM)
+        self.assertEqual(self.pool.head_v_dim, self.HEAD_V_DIM)
+
+    # ----------------------------------------------------------------
+    # 8. Pool allocate/free (GPU-side)
+    # ----------------------------------------------------------------
+    def test_pool_allocate_returns_valid_slots(self):
+        """Pool allocate() should return 1-based slot IDs."""
+        slots = self.pool.allocate(3)
+        self.assertEqual(len(slots), 3)
+        for s in slots:
+            self.assertGreater(s, 0)
+            self.assertLessEqual(s, self.MAX_NUM_SEQS)
+
+    def test_pool_allocate_exhaustion(self):
+        """Pool allocate() should raise when exhausted."""
+        self.pool.allocate(self.MAX_NUM_SEQS)
+        with self.assertRaises(RuntimeError):
+            self.pool.allocate(1)
+
+    def test_pool_free_recycles_slots(self):
+        """Pool free() should make slots available again and zero state."""
+        slots = self.pool.allocate(2)
+        # Write data to allocated slots
+        for s in slots:
+            self.pool.ssm_pool[0, s] = 42.0
+        self.pool.free(slots)
+        # State should be zeroed
+        for s in slots:
+            self.assertTrue(paddle.all(self.pool.ssm_pool[0, s] == 0).item())
+        # Should be able to allocate again
+        new_slots = self.pool.allocate(2)
+        self.assertEqual(len(new_slots), 2)
+
+    def test_pool_num_free_slots(self):
+        """num_free_slots should track available count."""
+        self.assertEqual(self.pool.num_free_slots, self.MAX_NUM_SEQS)
+        self.pool.allocate(3)
+        self.assertEqual(self.pool.num_free_slots, self.MAX_NUM_SEQS - 3)
+
+
+class TestGDNSlotAllocator(unittest.TestCase):
+    """Tests for the lightweight CPU-only slot allocator."""
+
+    MAX_NUM_SEQS = 8
+
+    def setUp(self):
+        self.allocator = GDNSlotAllocator(self.MAX_NUM_SEQS)
+
+    def test_allocate_returns_1_based(self):
+        """Allocated slots should be 1-based (slot 0 is sentinel)."""
+        slot = self.allocator.allocate()
+        self.assertGreater(slot, 0)
+        self.assertLessEqual(slot, self.MAX_NUM_SEQS)
+
+    def test_allocate_unique(self):
+        """Each allocation should return a unique slot."""
+        slots = set()
+        for _ in range(self.MAX_NUM_SEQS):
+            slots.add(self.allocator.allocate())
+        self.assertEqual(len(slots), self.MAX_NUM_SEQS)
+
+    def test_allocate_exhaustion_raises(self):
+        """Should raise RuntimeError when no free slots."""
+        for _ in range(self.MAX_NUM_SEQS):
+            self.allocator.allocate()
+        with self.assertRaises(RuntimeError):
+            self.allocator.allocate()
+
+    def test_free_makes_slot_available(self):
+        """Freed slot should be re-allocatable."""
+        slot = self.allocator.allocate()
+        self.allocator.free(slot)
+        new_slot = self.allocator.allocate()
+        self.assertEqual(new_slot, slot)  # LIFO: last freed = next allocated
+
+    def test_free_slot_zero_ignored(self):
+        """Freeing slot 0 (sentinel) should be a no-op."""
+        initial_free = self.allocator.num_free_slots
+        self.allocator.free(0)
+        self.assertEqual(self.allocator.num_free_slots, initial_free)
+
+    def test_num_free_slots(self):
+        """num_free_slots tracks available count correctly."""
+        self.assertEqual(self.allocator.num_free_slots, self.MAX_NUM_SEQS)
+        self.allocator.allocate()
+        self.assertEqual(self.allocator.num_free_slots, self.MAX_NUM_SEQS - 1)
+        self.allocator.allocate()
+        self.assertEqual(self.allocator.num_free_slots, self.MAX_NUM_SEQS - 2)
+
+    def test_allocate_free_lifecycle(self):
+        """Simulate a full request lifecycle: allocate → use → free → re-allocate."""
+        # Allocate all slots
+        all_slots = [self.allocator.allocate() for _ in range(self.MAX_NUM_SEQS)]
+        self.assertEqual(self.allocator.num_free_slots, 0)
+
+        # Free half
+        for s in all_slots[:4]:
+            self.allocator.free(s)
+        self.assertEqual(self.allocator.num_free_slots, 4)
+
+        # Re-allocate
+        new_slots = [self.allocator.allocate() for _ in range(4)]
+        self.assertEqual(self.allocator.num_free_slots, 0)
+        # All new_slots should be from the freed set
+        self.assertEqual(set(new_slots), set(all_slots[:4]))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/model_executor/test_gdn_backend.py b/tests/model_executor/test_gdn_backend.py
new file mode 100644
index 00000000000..417f0a91050
--- /dev/null
+++ b/tests/model_executor/test_gdn_backend.py
@@ -0,0 +1,546 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+GDNAttentionBackend integration tests — shape + numerical correctness.
+
+Tests decode / prefill (extend) / mixed mode forward pass through the full
+GDN backend pipeline: conv1d → split Q,K,V → gating → SSM kernel.
+
+Numerical correctness is verified by comparing the backend output against a
+manually-constructed reference pipeline that calls the same Triton kernels
+step-by-step, ensuring the backend's data orchestration (reshape, slice,
+unsqueeze, GVA repeat, etc.) is correct.
+
+Run:
+    cd FastDeploy
+    python -m pytest tests/model_executor/test_gdn_backend.py -v
+"""
+
+import unittest
+from enum import Enum, auto
+from types import SimpleNamespace
+
+import numpy as np
+import paddle
+
+from fastdeploy.cache_manager.gdn_state_pool import GDNStatePool
+from fastdeploy.model_executor.layers.attention.gdn_backend import GDNAttentionBackend
+from fastdeploy.model_executor.ops.triton_ops.causal_conv1d import (
+    causal_conv1d_fn,
+    causal_conv1d_update,
+)
+from fastdeploy.model_executor.ops.triton_ops.fla import (
+    chunk_gated_delta_rule,
+    fused_gdn_gating,
+    fused_recurrent_gated_delta_rule_update,
+)
+
+# ============================================================
+# Mock ForwardMode
+# ============================================================
+
+
+class ForwardMode(Enum):
+    EXTEND = auto()
+    DECODE = auto()
+    MIXED = auto()
+
+    def is_decode(self):
+        return self == ForwardMode.DECODE
+
+    def is_mixed(self):
+        return self == ForwardMode.MIXED
+
+
+# ============================================================
+# Test dimensions (small for speed)
+# ============================================================
+
+NUM_K_HEADS = 2
+NUM_V_HEADS = 4
+HEAD_K_DIM = 16
+HEAD_V_DIM = 16
+CONV_KERNEL_SIZE = 4
+NUM_V_HEADS_PER_K_HEAD = NUM_V_HEADS // NUM_K_HEADS
+
+KEY_DIM_LOCAL = NUM_K_HEADS * HEAD_K_DIM
+VALUE_DIM_LOCAL = NUM_V_HEADS * HEAD_V_DIM
+CONV_DIM = KEY_DIM_LOCAL * 2 + VALUE_DIM_LOCAL
+
+
+# ============================================================
+# Helpers
+# ============================================================
+
+
+def make_pool(max_num_seqs=8, num_layers=1):
+    return GDNStatePool(
+        max_num_seqs=max_num_seqs,
+        num_gdn_layers=num_layers,
+        conv_dim=CONV_DIM,
+        conv_kernel_size=CONV_KERNEL_SIZE,
+        num_v_heads=NUM_V_HEADS,
+        head_k_dim=HEAD_K_DIM,
+        head_v_dim=HEAD_V_DIM,
+    )
+
+
+def make_layer():
+    layer = SimpleNamespace()
+    layer.gdn_layer_idx = 0
+    layer.num_k_heads_local = NUM_K_HEADS
+    layer.num_v_heads_local = NUM_V_HEADS
+    layer.head_k_dim = HEAD_K_DIM
+    layer.head_v_dim = HEAD_V_DIM
+    layer.num_v_heads_per_k_head = NUM_V_HEADS_PER_K_HEAD
+    layer.conv_dim = CONV_DIM
+    layer.conv1d_weight = paddle.randn([CONV_DIM, CONV_KERNEL_SIZE], dtype="bfloat16")
+    layer.A_log = paddle.randn([NUM_V_HEADS], dtype="float32")
+    layer.dt_bias = paddle.randn([NUM_V_HEADS], dtype="float32")
+    return layer
+
+
+def make_meta_decode(batch_size, pool):
+    raw_slot_ids = paddle.arange(0, batch_size, dtype="int32")
+    meta = SimpleNamespace()
+    meta.forward_mode = ForwardMode.DECODE
+    meta.gdn_state_pool = pool
+    meta.gdn_slot_ids = raw_slot_ids
+    meta.gdn_has_initial_state = paddle.ones([batch_size], dtype="bool")
+    meta.gdn_seq_lens_cpu = [1] * batch_size
+    meta.cu_seqlens_q = paddle.arange(0, batch_size + 1, dtype="int32")
+    return meta
+
+
+def make_meta_extend(batch_size, seq_lens, pool):
+    cu = [0]
+    for sl in seq_lens:
+        cu.append(cu[-1] + sl)
+    raw_slot_ids = paddle.arange(0, batch_size, dtype="int32")
+    meta = SimpleNamespace()
+    meta.forward_mode = ForwardMode.EXTEND
+    meta.gdn_state_pool = pool
+    meta.gdn_slot_ids = raw_slot_ids
+    meta.gdn_has_initial_state = paddle.zeros([batch_size], dtype="bool")
+    meta.gdn_seq_lens_cpu = seq_lens
+    meta.cu_seqlens_q = paddle.to_tensor(cu, dtype="int32")
+    return meta
+
+
+def make_meta_mixed(num_decode, num_extend, extend_seq_lens, pool):
+    batch_size = num_decode + num_extend
+    all_seq_lens = [1] * num_decode + extend_seq_lens
+    cu = [0]
+    for sl in all_seq_lens:
+        cu.append(cu[-1] + sl)
+    raw_slot_ids = paddle.arange(0, batch_size, dtype="int32")
+    has_initial = [True] * num_decode + [False] * num_extend
+    meta = SimpleNamespace()
+    meta.forward_mode = ForwardMode.MIXED
+    meta.gdn_state_pool = pool
+    meta.gdn_slot_ids = raw_slot_ids
+    meta.gdn_has_initial_state = paddle.to_tensor(has_initial, dtype="bool")
+    meta.gdn_seq_lens_cpu = all_seq_lens
+    meta.cu_seqlens_q = paddle.to_tensor(cu, dtype="int32")
+    return meta
+
+
+# ============================================================
+# Reference pipeline: manually call the same Triton kernels
+# ============================================================
+
+
+def ref_decode_pipeline(mixed_qkv, a, b, layer, conv_pool, ssm_pool, slot_ids):
+    """Manual decode pipeline using Triton kernels directly."""
+    conv_weight = layer.conv1d_weight[: layer.conv_dim]
+
+    # 1. Conv1d update
+    x = causal_conv1d_update(
+        x=mixed_qkv,
+        conv_state=conv_pool,
+        weight=conv_weight,
+        bias=None,
+        activation="silu",
+        conv_state_indices=slot_ids,
+    )
+
+    # 2. Split Q, K, V
+    num_tokens = x.shape[0]
+    q, k, v = paddle.split(
+        x,
+        [KEY_DIM_LOCAL, KEY_DIM_LOCAL, VALUE_DIM_LOCAL],
+        axis=-1,
+    )
+    q = q.reshape([num_tokens, NUM_K_HEADS, HEAD_K_DIM])
+    k = k.reshape([num_tokens, NUM_K_HEADS, HEAD_K_DIM])
+    v = v.reshape([num_tokens, NUM_V_HEADS, HEAD_V_DIM])
+
+    # 3. Gating
+    g, beta = fused_gdn_gating(
+        layer.A_log[:NUM_V_HEADS],
+        a[:, :NUM_V_HEADS],
+        b[:, :NUM_V_HEADS],
+        layer.dt_bias[:NUM_V_HEADS],
+    )
+
+    # 4. GVA repeat
+    if NUM_V_HEADS_PER_K_HEAD > 1:
+        q = (
+            q.unsqueeze(2)
+            .expand([num_tokens, NUM_K_HEADS, NUM_V_HEADS_PER_K_HEAD, HEAD_K_DIM])
+            .reshape([num_tokens, NUM_V_HEADS, HEAD_K_DIM])
+        )
+        k = (
+            k.unsqueeze(2)
+            .expand([num_tokens, NUM_K_HEADS, NUM_V_HEADS_PER_K_HEAD, HEAD_K_DIM])
+            .reshape([num_tokens, NUM_V_HEADS, HEAD_K_DIM])
+        )
+
+    # 5. SSM kernel (decode)
+    q_4d = q.unsqueeze(1)
+    k_4d = k.unsqueeze(1)
+    v_4d = v.unsqueeze(1)
+    g_4d = g.unsqueeze(1)
+    beta_4d = beta.unsqueeze(1)
+
+    o = fused_recurrent_gated_delta_rule_update(
+        q=q_4d,
+        k=k_4d,
+        v=v_4d,
+        g=g_4d,
+        beta=beta_4d,
+        ssm_pool=ssm_pool,
+        ssm_indices=slot_ids,
+        use_qk_l2norm_in_kernel=True,
+    )
+    return o.squeeze(1)  # [batch, H, V]
+
+
+def ref_extend_pipeline(
+    mixed_qkv, a, b, layer, conv_pool, ssm_pool, slot_ids, cu_seqlens, seq_lens_cpu, has_initial_state
+):
+    """Manual extend pipeline using Triton kernels directly."""
+    conv_weight = layer.conv1d_weight[: layer.conv_dim]
+
+    # 1. Conv1d fn (varlen)
+    x = causal_conv1d_fn(
+        x=mixed_qkv.T,
+        weight=conv_weight,
+        bias=None,
+        conv_states=conv_pool,
+        query_start_loc=cu_seqlens,
+        seq_lens_cpu=seq_lens_cpu,
+        cache_indices=slot_ids,
+        has_initial_state=has_initial_state,
+        activation="silu",
+    ).T
+
+    # 2. Split Q, K, V
+    num_tokens = x.shape[0]
+    q, k, v = paddle.split(
+        x,
+        [KEY_DIM_LOCAL, KEY_DIM_LOCAL, VALUE_DIM_LOCAL],
+        axis=-1,
+    )
+    q = q.reshape([num_tokens, NUM_K_HEADS, HEAD_K_DIM])
+    k = k.reshape([num_tokens, NUM_K_HEADS, HEAD_K_DIM])
+    v = v.reshape([num_tokens, NUM_V_HEADS, HEAD_V_DIM])
+
+    # 3. Gating
+    g, beta = fused_gdn_gating(
+        layer.A_log[:NUM_V_HEADS],
+        a[:, :NUM_V_HEADS],
+        b[:, :NUM_V_HEADS],
+        layer.dt_bias[:NUM_V_HEADS],
+    )
+
+    # 4. GVA repeat
+    if NUM_V_HEADS_PER_K_HEAD > 1:
+        q = (
+            q.unsqueeze(2)
+            .expand([num_tokens, NUM_K_HEADS, NUM_V_HEADS_PER_K_HEAD, HEAD_K_DIM])
+            .reshape([num_tokens, NUM_V_HEADS, HEAD_K_DIM])
+        )
+        k = (
+            k.unsqueeze(2)
+            .expand([num_tokens, NUM_K_HEADS, NUM_V_HEADS_PER_K_HEAD, HEAD_K_DIM])
+            .reshape([num_tokens, NUM_V_HEADS, HEAD_K_DIM])
+        )
+
+    # 5. SSM kernel (chunk)
+    batch_size = slot_ids.shape[0]
+    q_4d = q.unsqueeze(0)
+    k_4d = k.unsqueeze(0)
+    v_4d = v.unsqueeze(0)
+    g_4d = g.unsqueeze(0)
+    beta_4d = beta.unsqueeze(0)
+
+    initial_state = ssm_pool[slot_ids].clone()
+    o, _h = chunk_gated_delta_rule(
+        q=q_4d,
+        k=k_4d,
+        v=v_4d,
+        g=g_4d,
+        beta=beta_4d,
+        scale=None,
+        initial_state=initial_state,
+        initial_state_indices=paddle.arange(batch_size, dtype=paddle.int32),
+        cu_seqlens=cu_seqlens,
+        use_qk_l2norm_in_kernel=True,
+    )
+
+    # Write back states
+    for i in range(batch_size):
+        sid = int(slot_ids[i])
+        if sid > 0:
+            ssm_pool[sid] = initial_state[i]
+
+    return o.squeeze(0)  # [total_tokens, H, V]
+
+
+# ============================================================
+# Test Cases
+# ============================================================
+
+
+class TestGDNBackendDecodeNumerical(unittest.TestCase):
+    """Decode: backend vs manual reference — numerical match."""
+
+    def test_decode_numerical(self):
+        batch_size = 3
+        pool = make_pool(max_num_seqs=8)
+        layer = make_layer()
+        backend = GDNAttentionBackend()
+        meta = make_meta_decode(batch_size, pool)
+
+        mixed_qkv = paddle.randn([batch_size, CONV_DIM], dtype="bfloat16")
+        a = paddle.randn([batch_size, NUM_V_HEADS], dtype="bfloat16")
+        b = paddle.randn([batch_size, NUM_V_HEADS], dtype="bfloat16")
+
+        # Clone pool + inputs so both paths see the same initial state
+        pool_ref = make_pool(max_num_seqs=8)
+        slot_ids = GDNStatePool.offset_slot_ids(meta.gdn_slot_ids)
+
+        # Reference
+        ref_out = ref_decode_pipeline(
+            mixed_qkv.clone(),
+            a.clone(),
+            b.clone(),
+            layer,
+            pool_ref.get_layer_conv_pool(0),
+            pool_ref.get_layer_ssm_pool(0),
+            slot_ids,
+        )
+
+        # Backend
+        out = backend.forward(mixed_qkv.clone(), a.clone(), b.clone(), layer, meta)
+
+        self.assertEqual(out.shape, [batch_size, NUM_V_HEADS, HEAD_V_DIM])
+        np.testing.assert_allclose(
+            out.cast("float32").numpy(),
+            ref_out.cast("float32").numpy(),
+            rtol=1e-3,
+            atol=1e-3,
+        )
+
+
+class TestGDNBackendExtendNumerical(unittest.TestCase):
+    """Extend: backend vs manual reference — numerical match."""
+
+    def test_extend_single_seq(self):
+        seq_lens = [10]
+        pool = make_pool(max_num_seqs=8)
+        layer = make_layer()
+        backend = GDNAttentionBackend()
+        meta = make_meta_extend(1, seq_lens, pool)
+
+        total = sum(seq_lens)
+        mixed_qkv = paddle.randn([total, CONV_DIM], dtype="bfloat16")
+        a = paddle.randn([total, NUM_V_HEADS], dtype="bfloat16")
+        b = paddle.randn([total, NUM_V_HEADS], dtype="bfloat16")
+
+        pool_ref = make_pool(max_num_seqs=8)
+        slot_ids = GDNStatePool.offset_slot_ids(meta.gdn_slot_ids)
+
+        ref_out = ref_extend_pipeline(
+            mixed_qkv.clone(),
+            a.clone(),
+            b.clone(),
+            layer,
+            pool_ref.get_layer_conv_pool(0),
+            pool_ref.get_layer_ssm_pool(0),
+            slot_ids,
+            meta.cu_seqlens_q.clone(),
+            list(meta.gdn_seq_lens_cpu),
+            meta.gdn_has_initial_state.clone(),
+        )
+
+        out = backend.forward(mixed_qkv.clone(), a.clone(), b.clone(), layer, meta)
+
+        self.assertEqual(out.shape, [total, NUM_V_HEADS, HEAD_V_DIM])
+        np.testing.assert_allclose(
+            out.cast("float32").numpy(),
+            ref_out.cast("float32").numpy(),
+            rtol=1e-3,
+            atol=1e-3,
+        )
+
+    def test_extend_multi_seq(self):
+        seq_lens = [5, 8, 3]
+        pool = make_pool(max_num_seqs=8)
+        layer = make_layer()
+        backend = GDNAttentionBackend()
+        meta = make_meta_extend(3, seq_lens, pool)
+
+        total = sum(seq_lens)
+        mixed_qkv = paddle.randn([total, CONV_DIM], dtype="bfloat16")
+        a = paddle.randn([total, NUM_V_HEADS], dtype="bfloat16")
+        b = paddle.randn([total, NUM_V_HEADS], dtype="bfloat16")
+
+        pool_ref = make_pool(max_num_seqs=8)
+        slot_ids = GDNStatePool.offset_slot_ids(meta.gdn_slot_ids)
+
+        ref_out = ref_extend_pipeline(
+            mixed_qkv.clone(),
+            a.clone(),
+            b.clone(),
+            layer,
+            pool_ref.get_layer_conv_pool(0),
+            pool_ref.get_layer_ssm_pool(0),
+            slot_ids,
+            meta.cu_seqlens_q.clone(),
+            list(meta.gdn_seq_lens_cpu),
+            meta.gdn_has_initial_state.clone(),
+        )
+
+        out = backend.forward(mixed_qkv.clone(), a.clone(), b.clone(), layer, meta)
+
+        self.assertEqual(out.shape, [total, NUM_V_HEADS, HEAD_V_DIM])
+        np.testing.assert_allclose(
+            out.cast("float32").numpy(),
+            ref_out.cast("float32").numpy(),
+            rtol=1e-3,
+            atol=1e-3,
+        )
+
+
+class TestGDNBackendMixedNumerical(unittest.TestCase):
+    """Mixed mode: backend vs manual reference — numerical match."""
+
+    def test_mixed_numerical(self):
+        """2 decode (seqlen=1) + 1 extend (seqlen=6), all through extend path."""
+        num_decode = 2
+        num_extend = 1
+        extend_seq_lens = [6]
+
+        pool = make_pool(max_num_seqs=8)
+        layer = make_layer()
+        backend = GDNAttentionBackend()
+        meta = make_meta_mixed(num_decode, num_extend, extend_seq_lens, pool)
+
+        total = num_decode + sum(extend_seq_lens)
+        mixed_qkv = paddle.randn([total, CONV_DIM], dtype="bfloat16")
+        a = paddle.randn([total, NUM_V_HEADS], dtype="bfloat16")
+        b = paddle.randn([total, NUM_V_HEADS], dtype="bfloat16")
+
+        pool_ref = make_pool(max_num_seqs=8)
+        slot_ids = GDNStatePool.offset_slot_ids(meta.gdn_slot_ids)
+
+        ref_out = ref_extend_pipeline(
+            mixed_qkv.clone(),
+            a.clone(),
+            b.clone(),
+            layer,
+            pool_ref.get_layer_conv_pool(0),
+            pool_ref.get_layer_ssm_pool(0),
+            slot_ids,
+            meta.cu_seqlens_q.clone(),
+            list(meta.gdn_seq_lens_cpu),
+            meta.gdn_has_initial_state.clone(),
+        )
+        out = backend.forward(mixed_qkv.clone(), a.clone(), b.clone(), layer, meta)
+
+        self.assertEqual(out.shape, [total, NUM_V_HEADS, HEAD_V_DIM])
+        np.testing.assert_allclose(
+            out.cast("float32").numpy(),
+            ref_out.cast("float32").numpy(),
+            rtol=1e-3,
+            atol=1e-3,
+        )
+
+
+class TestGDNBackendStateUpdate(unittest.TestCase):
+    """Verify SSM/conv states persist across prefill → decode."""
+
+    def test_prefill_then_decode_state_persists(self):
+        pool = make_pool(max_num_seqs=8)
+        layer = make_layer()
+        backend = GDNAttentionBackend()
+
+        # Step 1: Prefill seqlen=5
+        meta_p = make_meta_extend(1, [5], pool)
+        mixed_qkv = paddle.randn([5, CONV_DIM], dtype="bfloat16")
+        a = paddle.randn([5, NUM_V_HEADS], dtype="bfloat16")
+        b = paddle.randn([5, NUM_V_HEADS], dtype="bfloat16")
+        out1 = backend.forward(mixed_qkv, a, b, layer, meta_p)
+        self.assertEqual(out1.shape, [5, NUM_V_HEADS, HEAD_V_DIM])
+
+        # SSM state should be non-zero after prefill (slot 0 → offset 1)
+        ssm_state = pool.get_layer_ssm_pool(0)[1].numpy()
+        self.assertFalse((ssm_state == 0).all(), "SSM state should be non-zero after prefill")
+
+        # Conv state should be non-zero after prefill
+        conv_state = pool.get_layer_conv_pool(0)[1].cast("float32").numpy()
+        self.assertFalse((conv_state == 0).all(), "Conv state should be non-zero after prefill")
+
+        # Step 2: Decode 1 token (same slot)
+        meta_d = make_meta_decode(1, pool)
+        meta_d.gdn_has_initial_state = paddle.to_tensor([True], dtype="bool")
+        out2 = backend.forward(
+            paddle.randn([1, CONV_DIM], dtype="bfloat16"),
+            paddle.randn([1, NUM_V_HEADS], dtype="bfloat16"),
+            paddle.randn([1, NUM_V_HEADS], dtype="bfloat16"),
+            layer,
+            meta_d,
+        )
+        self.assertEqual(out2.shape, [1, NUM_V_HEADS, HEAD_V_DIM])
+
+        # SSM state should have changed after decode step
+        ssm_state_after = pool.get_layer_ssm_pool(0)[1].numpy()
+        self.assertFalse(
+            np.allclose(ssm_state, ssm_state_after, atol=1e-10), "SSM state should change after decode step"
+        )
+
+    def test_output_not_all_zeros(self):
+        """Sanity: output should not be trivially zero."""
+        pool = make_pool(max_num_seqs=8)
+        layer = make_layer()
+        backend = GDNAttentionBackend()
+        meta = make_meta_extend(1, [8], pool)
+
+        out = backend.forward(
+            paddle.randn([8, CONV_DIM], dtype="bfloat16"),
+            paddle.randn([8, NUM_V_HEADS], dtype="bfloat16"),
+            paddle.randn([8, NUM_V_HEADS], dtype="bfloat16"),
+            layer,
+            meta,
+        )
+        self.assertFalse(
+            (out.cast("float32").numpy() == 0).all(),
+            "Backend output should not be all zeros",
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()