From ced07862f606eb47b15dbdb5abdfc6b7b3efac43 Mon Sep 17 00:00:00 2001 From: liukebin Date: Mon, 30 Mar 2026 14:12:25 +0800 Subject: [PATCH 1/6] [Feature] support gdn kv cache manager for qwen3.5 --- fastdeploy/cache_manager/gdn_state_pool.py | 229 ++++++++++++ fastdeploy/engine/request.py | 3 + .../engine/sched/resource_manager_v1.py | 38 ++ fastdeploy/model_executor/forward_meta.py | 15 +- fastdeploy/worker/gpu_model_runner.py | 77 ++++ fastdeploy/worker/input_batch.py | 3 + tests/cache_manager/test_gdn_state_pool.py | 332 ++++++++++++++++++ 7 files changed, 696 insertions(+), 1 deletion(-) create mode 100644 fastdeploy/cache_manager/gdn_state_pool.py create mode 100644 tests/cache_manager/test_gdn_state_pool.py diff --git a/fastdeploy/cache_manager/gdn_state_pool.py b/fastdeploy/cache_manager/gdn_state_pool.py new file mode 100644 index 00000000000..2c8716e4cf2 --- /dev/null +++ b/fastdeploy/cache_manager/gdn_state_pool.py @@ -0,0 +1,229 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +GDN (Gated Delta Network) State Pool — pre-allocated GPU tensor pool +for conv and SSM states used by Qwen3.5 linear attention layers. + +Design: + - Analogous to paged KV cache block pool, but each slot stores a + complete per-request state (conv state + SSM state). + - All GDN layers share a single pool object, indexed by layer_idx. + - Slot 0 is reserved as a zero-filled padding sentinel. 
+ PAD_SLOT_ID (-1) is mapped to slot 0 via +1 offset when building + gdn_slot_ids in ForwardMeta, so reads return zero and writes are + harmless. + +Pool layouts: + conv_pool: [num_gdn_layers, pool_size, conv_dim, conv_kernel_size - 1] + ssm_pool: [num_gdn_layers, pool_size, num_v_heads, head_k_dim, head_v_dim] + + where pool_size = max_num_seqs + 1 (slot 0 = padding sentinel). +""" + +import logging +from typing import List + +import paddle + +logger = logging.getLogger(__name__) + +PAD_SLOT_ID = -1 + + +class GDNSlotAllocator: + """Lightweight CPU-only slot allocator for GDN state pool. + + Used by ResourceManagerV1 on the scheduler side to manage slot IDs + without requiring paddle/GPU access. The corresponding GPU tensors + live in GDNStatePool on the worker side. + + Slot 0 is reserved as a padding sentinel. Valid slots: 1..max_num_seqs. + """ + + def __init__(self, max_num_seqs: int): + self.max_num_seqs = max_num_seqs + self._free_slots: List[int] = list(range(max_num_seqs, 0, -1)) + + def allocate(self) -> int: + """Allocate a single slot ID. + + Returns: + Allocated slot ID (1-based). + + Raises: + RuntimeError: If no free slots available. + """ + if not self._free_slots: + raise RuntimeError(f"GDNSlotAllocator: no free slots (max_num_seqs={self.max_num_seqs})") + return self._free_slots.pop() + + def free(self, slot_id: int): + """Return a slot ID to the free list. + + Args: + slot_id: Slot ID to free (1-based). Slot 0 is silently ignored. + """ + if slot_id > 0: + self._free_slots.append(slot_id) + + @property + def num_free_slots(self) -> int: + return len(self._free_slots) + + +class GDNStatePool: + """Pre-allocated GPU tensor pool for GDN conv and SSM states. + + Args: + max_num_seqs: Maximum number of concurrent sequences. + num_gdn_layers: Number of GDN (linear_attention) layers in the model. + conv_dim: TP-local convolution dimension (key_dim * 2 + value_dim) // tp_size. + conv_kernel_size: Causal conv1d kernel width (e.g. 4). 
+ num_v_heads: TP-local number of value heads (num_v_heads // tp_size). + head_k_dim: Per-head key dimension. + head_v_dim: Per-head value dimension. + conv_dtype: Data type for conv state pool (default: bfloat16). + """ + + def __init__( + self, + max_num_seqs: int, + num_gdn_layers: int, + conv_dim: int, + conv_kernel_size: int, + num_v_heads: int, + head_k_dim: int, + head_v_dim: int, + conv_dtype: paddle.dtype = paddle.bfloat16, + ): + self.max_num_seqs = max_num_seqs + self.num_gdn_layers = num_gdn_layers + self.conv_dim = conv_dim + self.conv_kernel_size = conv_kernel_size + self.num_v_heads = num_v_heads + self.head_k_dim = head_k_dim + self.head_v_dim = head_v_dim + + # pool_size = max_num_seqs + 1; slot 0 is the padding sentinel + pool_size = max_num_seqs + 1 + + # Conv state pool: [num_gdn_layers, pool_size, conv_dim, conv_kernel_size - 1] + conv_state_len = conv_kernel_size - 1 + self.conv_pool = paddle.zeros( + [num_gdn_layers, pool_size, conv_dim, conv_state_len], + dtype=conv_dtype, + ) + + # SSM state pool: [num_gdn_layers, pool_size, num_v_heads, head_k_dim, head_v_dim] + # K-first layout matching FLA kernel native format. + # float32 for numerical stability (SSM state accumulates over many steps). 
+ self.ssm_pool = paddle.zeros( + [num_gdn_layers, pool_size, num_v_heads, head_k_dim, head_v_dim], + dtype=paddle.float32, + ) + + conv_mem_mb = (num_gdn_layers * pool_size * conv_dim * conv_state_len * paddle.finfo(conv_dtype).bits // 8) / ( + 1024 * 1024 + ) + ssm_mem_mb = (num_gdn_layers * pool_size * num_v_heads * head_k_dim * head_v_dim * 4) / (1024 * 1024) + logger.info( + f"GDNStatePool allocated: " + f"conv_pool {list(self.conv_pool.shape)} ({conv_mem_mb:.1f} MB), " + f"ssm_pool {list(self.ssm_pool.shape)} ({ssm_mem_mb:.1f} MB)" + ) + + # Free slot list: valid slots are 1..max_num_seqs (slot 0 is sentinel) + self._free_slots: List[int] = list(range(max_num_seqs, 0, -1)) + + logger.info( + f"GDNStatePool: {len(self._free_slots)} free slots available " f"(slot 0 reserved as padding sentinel)" + ) + + def allocate(self, n: int = 1) -> List[int]: + """Allocate n slot IDs from the free list. + + Args: + n: Number of slots to allocate. + + Returns: + List of allocated slot IDs (1-based, already offset for pool indexing). + + Raises: + RuntimeError: If not enough free slots available. + """ + if len(self._free_slots) < n: + raise RuntimeError(f"GDNStatePool: cannot allocate {n} slots, " f"only {len(self._free_slots)} free") + allocated = [self._free_slots.pop() for _ in range(n)] + return allocated + + def free(self, slot_ids: List[int]): + """Return slot IDs to the free list and zero-out their state. + + Args: + slot_ids: List of slot IDs to free (1-based pool indices). + Slot 0 (padding sentinel) is silently ignored. + """ + valid = [s for s in slot_ids if s > 0] + if not valid: + return + self.reset_slots(valid) + self._free_slots.extend(valid) + + @property + def num_free_slots(self) -> int: + """Number of currently available slots.""" + return len(self._free_slots) + + def get_layer_conv_pool(self, layer_idx: int) -> paddle.Tensor: + """Get conv state pool for a specific GDN layer. 
+ + Returns: + Tensor of shape [pool_size, conv_dim, conv_kernel_size - 1] + """ + return self.conv_pool[layer_idx] + + def get_layer_ssm_pool(self, layer_idx: int) -> paddle.Tensor: + """Get SSM state pool for a specific GDN layer. + + Returns: + Tensor of shape [pool_size, num_v_heads, head_k_dim, head_v_dim] + """ + return self.ssm_pool[layer_idx] + + def reset_slots(self, slot_ids: List[int]): + """Zero-out conv and SSM state for given slots across all layers. + + Used when requests finish and their slots are returned to the free list. + + Args: + slot_ids: List of slot indices to reset (already +1 offset applied). + """ + if not slot_ids: + return + idx = paddle.to_tensor(slot_ids, dtype=paddle.int64) + for layer_idx in range(self.num_gdn_layers): + self.conv_pool[layer_idx][idx] = 0 + self.ssm_pool[layer_idx][idx] = 0 + + @staticmethod + def offset_slot_ids(raw_slot_ids: paddle.Tensor) -> paddle.Tensor: + """Apply +1 offset to raw slot IDs so PAD_SLOT_ID (-1) maps to slot 0. + + Args: + raw_slot_ids: [batch_size] int32, may contain PAD_SLOT_ID (-1). + + Returns: + Offset slot IDs where -1 -> 0, 0 -> 1, 1 -> 2, etc. + """ + return raw_slot_ids + 1 diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 391e2038534..ecf8a3f780c 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -188,6 +188,9 @@ def __init__( self.task_type = RequestType.PREFILL self.has_been_preempted_before = False self.idx = None + # GDN (Gated Delta Network) SSM state slot ID in GDNStatePool. + # Allocated by ResourceManagerV1 during schedule(), freed in _free_blocks(). 
+ self.gdn_slot_id = None self.need_prefill_tokens = self.prompt_token_ids_len self.audio_output_token_ids = [] # extend block tables diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index b0425d779d1..d382d41d0d2 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -28,6 +28,7 @@ import paddle from fastdeploy import envs +from fastdeploy.cache_manager.gdn_state_pool import GDNSlotAllocator from fastdeploy.cache_manager.multimodal_cache_manager import ( EncoderCacheManager, ProcessorCacheManager, @@ -231,6 +232,32 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l # Scheduler-side requests that have not been moved into resource manager waiting queue yet. self.scheduler_unhandled_request_num = 0 + # GDN SSM slot allocator (None if model has no GDN layers) + self.gdn_slot_allocator = self._init_gdn_slot_allocator() + + def _init_gdn_slot_allocator(self): + """Create GDN slot allocator if the model has GDN (linear_attention) layers.""" + model_config = self.config.model_config + layer_types = getattr(model_config, "layer_types", None) + if layer_types is None: + # Generate from full_attention_interval if not explicit + interval = getattr(model_config, "full_attention_interval", None) + if interval is None: + return None + num_layers = model_config.num_hidden_layers + layer_types = [ + "linear_attention" if (i + 1) % interval != 0 else "full_attention" for i in range(num_layers) + ] + num_gdn_layers = sum(1 for lt in layer_types if lt == "linear_attention") + if num_gdn_layers == 0: + return None + allocator = GDNSlotAllocator(self.config.scheduler_config.max_num_seqs) + llm_logger.info( + f"GDN slot allocator initialized: {num_gdn_layers} GDN layers, " + f"max_num_seqs={self.config.scheduler_config.max_num_seqs}" + ) + return allocator + def allocated_slots(self, request: Request): return len(request.block_tables) 
* self.config.cache_config.block_size @@ -997,6 +1024,9 @@ def _allocate_decode_and_extend(): request.block_tables.extend(extra_gpu_block_ids) self.waiting.popleft() self.running.append(request) + # Allocate GDN SSM state slot if model has GDN layers + if self.gdn_slot_allocator is not None and request.gdn_slot_id is None: + request.gdn_slot_id = self.gdn_slot_allocator.allocate() scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) token_budget -= num_new_tokens request.num_computed_tokens += num_new_tokens @@ -1062,6 +1092,9 @@ def _allocate_decode_and_extend(): request.block_tables.extend(extra_gpu_block_ids) self.waiting.popleft() self.running.append(request) + # Re-allocate GDN slot for preempted request + if self.gdn_slot_allocator is not None and request.gdn_slot_id is None: + request.gdn_slot_id = self.gdn_slot_allocator.allocate() scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) token_budget -= num_new_tokens request.num_computed_tokens += num_new_tokens @@ -1452,6 +1485,11 @@ def _free_blocks(self, request: Request): self.cache_manager.recycle_gpu_blocks(request.block_tables, request.request_id) request.block_tables = [] + # Free GDN SSM state slot + if self.gdn_slot_allocator is not None and request.gdn_slot_id is not None: + self.gdn_slot_allocator.free(request.gdn_slot_id) + request.gdn_slot_id = None + if request.request_id in self.using_extend_tables_req_id: reuse_block_num = self.reuse_block_num_map[request.request_id] diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index 9e512f32355..235ea2a820a 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -17,7 +17,7 @@ import logging from dataclasses import dataclass, fields from enum import IntEnum, auto -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional import paddle @@ -160,6 +160,19 @@ class 
ForwardMeta: position_ids: Optional[paddle.Tensor] = None + # ============================================================ + # GDN (Gated Delta Network) linear attention fields + # ============================================================ + # GDN state pool object (shared across all GDN layers) + gdn_state_pool: Optional[Any] = None + # Slot indices into the GDN state pool [batch_size], int32. + # PAD_SLOT_ID=-1 requests are offset to slot 0 (zero-filled sentinel). + gdn_slot_ids: Optional[paddle.Tensor] = None + # Whether each request has prior state (False = new request) [batch_size], bool + gdn_has_initial_state: Optional[paddle.Tensor] = None + # CPU sequence lengths for causal_conv1d_fn varlen [batch_size], int32 + gdn_seq_lens_cpu: Optional[List[int]] = None + def clear_caches(self): """Safely clean up the caches""" if self.caches: diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index bed61bd5b1d..7682caac64c 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -841,6 +841,9 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( request.block_tables, dtype="int32" ) + # Write GDN slot ID for this request + if hasattr(request, "gdn_slot_id") and request.gdn_slot_id is not None: + self.share_inputs["gdn_slot_ids"][idx : idx + 1] = request.gdn_slot_id self.share_inputs["stop_flags"][idx : idx + 1] = False self.share_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length @@ -905,6 +908,8 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = logger.info(f"Handle abort request {request} at idx {idx}") self.share_inputs["preempted_idx"][idx : idx + 1, :] = 1 self.share_inputs["block_tables"][idx : idx + 1, :] = -1 + # Reset GDN slot ID to PAD + 
self.share_inputs["gdn_slot_ids"][idx : idx + 1] = -1 self.share_inputs["stop_flags"][idx : idx + 1] = True self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = 0 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 @@ -1349,6 +1354,20 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): self.forward_meta.is_zero_size = self.forward_meta.ids_remove_padding.shape[0] == 0 self.forward_meta.exist_prefill = self.exist_prefill() + # Populate GDN state pool fields if available + if getattr(self, "gdn_state_pool", None) is not None: + self.forward_meta.gdn_state_pool = self.gdn_state_pool + # gdn_slot_ids: from share_inputs (set by scheduler via insert_tasks_v1) + self.forward_meta.gdn_slot_ids = self.share_inputs.get("gdn_slot_ids") + # gdn_has_initial_state: True if request has prior state (seq_lens_decoder > 0) + # For prefill of a new request, seq_lens_decoder=0 → has_initial_state=False + # For decode or chunked prefill continuation, seq_lens_decoder>0 → has_initial_state=True + if self.forward_meta.seq_lens_decoder is not None: + self.forward_meta.gdn_has_initial_state = self.forward_meta.seq_lens_decoder > 0 + # Derive gdn_seq_lens_cpu from seq_lens_this_time + if self.forward_meta.seq_lens_this_time is not None: + self.forward_meta.gdn_seq_lens_cpu = self.forward_meta.seq_lens_this_time.numpy().tolist() + def initialize_kv_cache(self, profile: bool = False) -> None: """ Initialize kv cache @@ -1499,6 +1518,64 @@ def initialize_kv_cache(self, profile: bool = False) -> None: paddle.device.cuda.empty_cache() logger.info("kv cache is initialized!") + # Initialize GDN state pool if the model has GDN layers + self._initialize_gdn_state_pool() + + def _initialize_gdn_state_pool(self) -> None: + """Initialize GDN (Gated Delta Network) state pool if the model uses GDN layers.""" + config = self.model_config + + # Detect GDN layers by checking for layer_types config + layer_types = getattr(config, "layer_types", None) + if layer_types is None: 
+ # Generate from full_attention_interval if not explicit + interval = getattr(config, "full_attention_interval", None) + if interval is None: + self.gdn_state_pool = None + return + num_layers = config.num_hidden_layers + layer_types = [ + "linear_attention" if (i + 1) % interval != 0 else "full_attention" for i in range(num_layers) + ] + + num_gdn_layers = sum(1 for lt in layer_types if lt == "linear_attention") + if num_gdn_layers == 0: + self.gdn_state_pool = None + return + + # GDN-specific dimensions + head_k_dim = getattr(config, "linear_key_head_dim", 128) + head_v_dim = getattr(config, "linear_value_head_dim", 128) + num_k_heads = getattr(config, "linear_num_key_heads", 16) + num_v_heads = getattr(config, "linear_num_value_heads", 16) + conv_kernel_size = getattr(config, "linear_conv_kernel_dim", 4) + + tp_size = self.parallel_config.tensor_parallel_size + + key_dim = head_k_dim * num_k_heads + value_dim = head_v_dim * num_v_heads + conv_dim_local = (key_dim * 2 + value_dim) // tp_size + num_v_heads_local = num_v_heads // tp_size + + max_num_seqs = self.scheduler_config.max_num_seqs + + from fastdeploy.cache_manager.gdn_state_pool import GDNStatePool + + self.gdn_state_pool = GDNStatePool( + max_num_seqs=max_num_seqs, + num_gdn_layers=num_gdn_layers, + conv_dim=conv_dim_local, + conv_kernel_size=conv_kernel_size, + num_v_heads=num_v_heads_local, + head_k_dim=head_k_dim, + head_v_dim=head_v_dim, + ) + logger.info( + f"GDN state pool initialized: {num_gdn_layers} GDN layers, " + f"max_num_seqs={max_num_seqs}, conv_dim_local={conv_dim_local}, " + f"num_v_heads_local={num_v_heads_local}" + ) + def _initialize_attn_backend(self) -> None: """ Initialize attention backends diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index 56858aae516..09d48fc9dff 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -232,6 +232,9 @@ def init_share_inputs(self): ) // self.cache_config.block_size + 
self.cache_config.enc_dec_block_num self.block_tables = paddle.full([max_num_seqs, pre_max_block_num], -1, dtype="int32") + # GDN SSM state pool slot IDs (PAD_SLOT_ID=-1 for empty slots) + self.gdn_slot_ids = paddle.full([max_num_seqs], -1, dtype="int32") + # Initialize free list free_list = list( range( diff --git a/tests/cache_manager/test_gdn_state_pool.py b/tests/cache_manager/test_gdn_state_pool.py new file mode 100644 index 00000000000..61477227792 --- /dev/null +++ b/tests/cache_manager/test_gdn_state_pool.py @@ -0,0 +1,332 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Unit tests for GDNStatePool. + +Tests cover: + 1. Pool allocation shapes and dtypes + 2. Slot 0 padding sentinel invariant + 3. Per-layer pool view indexing + 4. reset_slots zeroing across all layers + 5. offset_slot_ids PAD_SLOT_ID mapping + 6. Read/write round-trip per slot + 7. Multi-layer independence +""" + +import unittest + +import paddle + +from fastdeploy.cache_manager.gdn_state_pool import ( + PAD_SLOT_ID, + GDNSlotAllocator, + GDNStatePool, +) + + +class TestGDNStatePool(unittest.TestCase): + """Tests for GDNStatePool construction and basic operations.""" + + # Typical Qwen3.5 config (small scale for testing) + MAX_NUM_SEQS = 8 + NUM_GDN_LAYERS = 4 + CONV_DIM = 64 # e.g. 
(key_dim*2 + value_dim) // tp_size + CONV_KERNEL_SIZE = 4 + NUM_V_HEADS = 4 # TP-local + HEAD_K_DIM = 16 + HEAD_V_DIM = 16 + + def setUp(self): + self.pool = GDNStatePool( + max_num_seqs=self.MAX_NUM_SEQS, + num_gdn_layers=self.NUM_GDN_LAYERS, + conv_dim=self.CONV_DIM, + conv_kernel_size=self.CONV_KERNEL_SIZE, + num_v_heads=self.NUM_V_HEADS, + head_k_dim=self.HEAD_K_DIM, + head_v_dim=self.HEAD_V_DIM, + ) + + # ---------------------------------------------------------------- + # 1. Shape and dtype checks + # ---------------------------------------------------------------- + def test_conv_pool_shape(self): + """Conv pool shape: [num_gdn_layers, pool_size, conv_dim, conv_kernel_size-1]""" + pool_size = self.MAX_NUM_SEQS + 1 + expected = [self.NUM_GDN_LAYERS, pool_size, self.CONV_DIM, self.CONV_KERNEL_SIZE - 1] + self.assertEqual(list(self.pool.conv_pool.shape), expected) + + def test_ssm_pool_shape(self): + """SSM pool shape: [num_gdn_layers, pool_size, num_v_heads, head_k_dim, head_v_dim]""" + pool_size = self.MAX_NUM_SEQS + 1 + expected = [self.NUM_GDN_LAYERS, pool_size, self.NUM_V_HEADS, self.HEAD_K_DIM, self.HEAD_V_DIM] + self.assertEqual(list(self.pool.ssm_pool.shape), expected) + + def test_conv_pool_dtype(self): + """Conv pool should be bfloat16 by default.""" + self.assertEqual(self.pool.conv_pool.dtype, paddle.bfloat16) + + def test_ssm_pool_dtype(self): + """SSM pool should be float32 for numerical stability.""" + self.assertEqual(self.pool.ssm_pool.dtype, paddle.float32) + + def test_custom_conv_dtype(self): + """Conv pool should respect custom dtype.""" + pool = GDNStatePool( + max_num_seqs=4, + num_gdn_layers=1, + conv_dim=16, + conv_kernel_size=4, + num_v_heads=2, + head_k_dim=8, + head_v_dim=8, + conv_dtype=paddle.float32, + ) + self.assertEqual(pool.conv_pool.dtype, paddle.float32) + + # ---------------------------------------------------------------- + # 2. 
Slot 0 padding sentinel + # ---------------------------------------------------------------- + def test_slot_zero_is_zero_after_init(self): + """Slot 0 (padding sentinel) must be all-zeros after construction.""" + for layer_idx in range(self.NUM_GDN_LAYERS): + conv_slot0 = self.pool.get_layer_conv_pool(layer_idx)[0] + ssm_slot0 = self.pool.get_layer_ssm_pool(layer_idx)[0] + self.assertTrue(paddle.all(conv_slot0 == 0).item()) + self.assertTrue(paddle.all(ssm_slot0 == 0).item()) + + def test_slot_zero_stays_zero_after_write(self): + """Writing to slot 0 should work (it's just a safety net for PAD reads).""" + # Simulate a padded write going to slot 0 + self.pool.conv_pool[0, 0] = 999.0 + # After reset_slots, slot 0 should be zero again + self.pool.reset_slots([0]) + conv_slot0 = self.pool.get_layer_conv_pool(0)[0] + self.assertTrue(paddle.all(conv_slot0 == 0).item()) + + # ---------------------------------------------------------------- + # 3. Per-layer pool view + # ---------------------------------------------------------------- + def test_get_layer_conv_pool_shape(self): + """get_layer_conv_pool returns [pool_size, conv_dim, conv_kernel_size-1].""" + pool_size = self.MAX_NUM_SEQS + 1 + view = self.pool.get_layer_conv_pool(0) + expected = [pool_size, self.CONV_DIM, self.CONV_KERNEL_SIZE - 1] + self.assertEqual(list(view.shape), expected) + + def test_get_layer_ssm_pool_shape(self): + """get_layer_ssm_pool returns [pool_size, num_v_heads, head_k_dim, head_v_dim].""" + pool_size = self.MAX_NUM_SEQS + 1 + view = self.pool.get_layer_ssm_pool(0) + expected = [pool_size, self.NUM_V_HEADS, self.HEAD_K_DIM, self.HEAD_V_DIM] + self.assertEqual(list(view.shape), expected) + + def test_layer_views_are_independent(self): + """Writing to layer 0's pool should not affect layer 1's pool.""" + self.pool.get_layer_ssm_pool(0)[1] = 42.0 + ssm_layer1_slot1 = self.pool.get_layer_ssm_pool(1)[1] + self.assertTrue(paddle.all(ssm_layer1_slot1 == 0).item()) + + # 
---------------------------------------------------------------- + # 4. reset_slots + # ---------------------------------------------------------------- + def test_reset_slots_zeros_conv_and_ssm(self): + """reset_slots should zero out conv and SSM state for given slots across all layers.""" + # Write non-zero data to slots 1, 2, 3 + for layer_idx in range(self.NUM_GDN_LAYERS): + for slot in [1, 2, 3]: + self.pool.conv_pool[layer_idx, slot] = float(slot) + self.pool.ssm_pool[layer_idx, slot] = float(slot) + + # Reset slots 1 and 3 + self.pool.reset_slots([1, 3]) + + for layer_idx in range(self.NUM_GDN_LAYERS): + # Slot 1 and 3 should be zero + self.assertTrue(paddle.all(self.pool.conv_pool[layer_idx, 1] == 0).item()) + self.assertTrue(paddle.all(self.pool.ssm_pool[layer_idx, 1] == 0).item()) + self.assertTrue(paddle.all(self.pool.conv_pool[layer_idx, 3] == 0).item()) + self.assertTrue(paddle.all(self.pool.ssm_pool[layer_idx, 3] == 0).item()) + # Slot 2 should still have data + self.assertTrue(paddle.all(self.pool.conv_pool[layer_idx, 2] == 2.0).item()) + self.assertTrue(paddle.all(self.pool.ssm_pool[layer_idx, 2] == 2.0).item()) + + def test_reset_slots_empty_list(self): + """reset_slots with empty list should be a no-op.""" + self.pool.ssm_pool[0, 1] = 99.0 + self.pool.reset_slots([]) + self.assertTrue(paddle.all(self.pool.ssm_pool[0, 1] == 99.0).item()) + + # ---------------------------------------------------------------- + # 5. 
offset_slot_ids + # ---------------------------------------------------------------- + def test_offset_slot_ids_basic(self): + """offset_slot_ids: -1->0, 0->1, 1->2, etc.""" + raw = paddle.to_tensor([-1, 0, 1, 5], dtype=paddle.int32) + offset = GDNStatePool.offset_slot_ids(raw) + expected = paddle.to_tensor([0, 1, 2, 6], dtype=paddle.int32) + self.assertTrue(paddle.all(offset == expected).item()) + + def test_offset_slot_ids_pad_maps_to_sentinel(self): + """PAD_SLOT_ID (-1) should map to slot 0 (the zero-filled sentinel).""" + raw = paddle.to_tensor([PAD_SLOT_ID], dtype=paddle.int32) + offset = GDNStatePool.offset_slot_ids(raw) + self.assertEqual(offset[0].item(), 0) + + # ---------------------------------------------------------------- + # 6. Read/write round-trip + # ---------------------------------------------------------------- + def test_ssm_read_write_roundtrip(self): + """Write a known state to a slot, read it back, verify equality.""" + layer_idx = 2 + slot_id = 5 # after +1 offset + state = paddle.randn([self.NUM_V_HEADS, self.HEAD_K_DIM, self.HEAD_V_DIM], dtype=paddle.float32) + + # Write + self.pool.get_layer_ssm_pool(layer_idx)[slot_id] = state + + # Read back + read_back = self.pool.get_layer_ssm_pool(layer_idx)[slot_id] + self.assertTrue(paddle.allclose(read_back, state).item()) + + def test_conv_read_write_roundtrip(self): + """Write a known conv state to a slot, read it back, verify equality.""" + layer_idx = 1 + slot_id = 3 + state = paddle.randn([self.CONV_DIM, self.CONV_KERNEL_SIZE - 1], dtype=paddle.bfloat16) + + self.pool.get_layer_conv_pool(layer_idx)[slot_id] = state + read_back = self.pool.get_layer_conv_pool(layer_idx)[slot_id] + self.assertTrue(paddle.allclose(read_back.cast(paddle.float32), state.cast(paddle.float32), atol=1e-2).item()) + + # ---------------------------------------------------------------- + # 7. 
Stored attributes + # ---------------------------------------------------------------- + def test_stored_attributes(self): + """Pool should store construction parameters for later inspection.""" + self.assertEqual(self.pool.max_num_seqs, self.MAX_NUM_SEQS) + self.assertEqual(self.pool.num_gdn_layers, self.NUM_GDN_LAYERS) + self.assertEqual(self.pool.conv_dim, self.CONV_DIM) + self.assertEqual(self.pool.conv_kernel_size, self.CONV_KERNEL_SIZE) + self.assertEqual(self.pool.num_v_heads, self.NUM_V_HEADS) + self.assertEqual(self.pool.head_k_dim, self.HEAD_K_DIM) + self.assertEqual(self.pool.head_v_dim, self.HEAD_V_DIM) + + # ---------------------------------------------------------------- + # 8. Pool allocate/free (GPU-side) + # ---------------------------------------------------------------- + def test_pool_allocate_returns_valid_slots(self): + """Pool allocate() should return 1-based slot IDs.""" + slots = self.pool.allocate(3) + self.assertEqual(len(slots), 3) + for s in slots: + self.assertGreater(s, 0) + self.assertLessEqual(s, self.MAX_NUM_SEQS) + + def test_pool_allocate_exhaustion(self): + """Pool allocate() should raise when exhausted.""" + self.pool.allocate(self.MAX_NUM_SEQS) + with self.assertRaises(RuntimeError): + self.pool.allocate(1) + + def test_pool_free_recycles_slots(self): + """Pool free() should make slots available again and zero state.""" + slots = self.pool.allocate(2) + # Write data to allocated slots + for s in slots: + self.pool.ssm_pool[0, s] = 42.0 + self.pool.free(slots) + # State should be zeroed + for s in slots: + self.assertTrue(paddle.all(self.pool.ssm_pool[0, s] == 0).item()) + # Should be able to allocate again + new_slots = self.pool.allocate(2) + self.assertEqual(len(new_slots), 2) + + def test_pool_num_free_slots(self): + """num_free_slots should track available count.""" + self.assertEqual(self.pool.num_free_slots, self.MAX_NUM_SEQS) + self.pool.allocate(3) + self.assertEqual(self.pool.num_free_slots, self.MAX_NUM_SEQS - 3) + + 
+class TestGDNSlotAllocator(unittest.TestCase): + """Tests for the lightweight CPU-only slot allocator.""" + + MAX_NUM_SEQS = 8 + + def setUp(self): + self.allocator = GDNSlotAllocator(self.MAX_NUM_SEQS) + + def test_allocate_returns_1_based(self): + """Allocated slots should be 1-based (slot 0 is sentinel).""" + slot = self.allocator.allocate() + self.assertGreater(slot, 0) + self.assertLessEqual(slot, self.MAX_NUM_SEQS) + + def test_allocate_unique(self): + """Each allocation should return a unique slot.""" + slots = set() + for _ in range(self.MAX_NUM_SEQS): + slots.add(self.allocator.allocate()) + self.assertEqual(len(slots), self.MAX_NUM_SEQS) + + def test_allocate_exhaustion_raises(self): + """Should raise RuntimeError when no free slots.""" + for _ in range(self.MAX_NUM_SEQS): + self.allocator.allocate() + with self.assertRaises(RuntimeError): + self.allocator.allocate() + + def test_free_makes_slot_available(self): + """Freed slot should be re-allocatable.""" + slot = self.allocator.allocate() + self.allocator.free(slot) + new_slot = self.allocator.allocate() + self.assertEqual(new_slot, slot) # LIFO: last freed = next allocated + + def test_free_slot_zero_ignored(self): + """Freeing slot 0 (sentinel) should be a no-op.""" + initial_free = self.allocator.num_free_slots + self.allocator.free(0) + self.assertEqual(self.allocator.num_free_slots, initial_free) + + def test_num_free_slots(self): + """num_free_slots tracks available count correctly.""" + self.assertEqual(self.allocator.num_free_slots, self.MAX_NUM_SEQS) + self.allocator.allocate() + self.assertEqual(self.allocator.num_free_slots, self.MAX_NUM_SEQS - 1) + self.allocator.allocate() + self.assertEqual(self.allocator.num_free_slots, self.MAX_NUM_SEQS - 2) + + def test_allocate_free_lifecycle(self): + """Simulate a full request lifecycle: allocate → use → free → re-allocate.""" + # Allocate all slots + all_slots = [self.allocator.allocate() for _ in range(self.MAX_NUM_SEQS)] + 
self.assertEqual(self.allocator.num_free_slots, 0) + + # Free half + for s in all_slots[:4]: + self.allocator.free(s) + self.assertEqual(self.allocator.num_free_slots, 4) + + # Re-allocate + new_slots = [self.allocator.allocate() for _ in range(4)] + self.assertEqual(self.allocator.num_free_slots, 0) + # All new_slots should be from the freed set + self.assertEqual(set(new_slots), set(all_slots[:4])) + + +if __name__ == "__main__": + unittest.main() From 30ed7920f45faa306dc182a70cdd1c0539e4272f Mon Sep 17 00:00:00 2001 From: liukebin Date: Tue, 31 Mar 2026 11:25:17 +0800 Subject: [PATCH 2/6] [Feature] support gdn backend for qwen3.5 --- fastdeploy/model_executor/forward_meta.py | 2 + .../layers/attention/gdn_backend.py | 712 ++++++++++++++++++ fastdeploy/worker/gpu_model_runner.py | 11 + 3 files changed, 725 insertions(+) create mode 100644 fastdeploy/model_executor/layers/attention/gdn_backend.py diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index 235ea2a820a..27028b65213 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -172,6 +172,8 @@ class ForwardMeta: gdn_has_initial_state: Optional[paddle.Tensor] = None # CPU sequence lengths for causal_conv1d_fn varlen [batch_size], int32 gdn_seq_lens_cpu: Optional[List[int]] = None + # GDN attention backend (GDNAttentionBackend instance, inherits AttentionBackend) + gdn_attn_backend: Optional[Any] = None def clear_caches(self): """Safely clean up the caches""" diff --git a/fastdeploy/model_executor/layers/attention/gdn_backend.py b/fastdeploy/model_executor/layers/attention/gdn_backend.py new file mode 100644 index 00000000000..5384008f66e --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/gdn_backend.py @@ -0,0 +1,712 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +GDN (Gated Delta Network) Attention Backend for Qwen3.5 linear attention. + +Architecture (inspired by SGLang's gdn_backend.py): + - GDNKernelDispatcher: strategy pattern routing to Triton / Paddle fallback kernels + - GDNAttentionBackend(AttentionBackend): unified forward entry for model layer + +Call chain: + Model Layer: Qwen3_5GatedDeltaNet.forward() + |-- projections (qkv, z, b, a) + |-- forward_meta.gdn_attn_backend.forward(mixed_qkv, a, b, layer, forward_meta) + | + GDNAttentionBackend.forward() + |-- causal_conv1d (decode / prefill / fallback) + |-- split Q,K,V + fused_gdn_gating + |-- GVA repeat (if needed) + |-- kernel_dispatcher.decode/extend/fallback(...) 
+ | + GDNKernelDispatcher + |-- Triton FLA kernels (decode: fused_recurrent, extend: chunk) + |-- Paddle fallback (paddle_recurrent / paddle_chunk) + |-- gated RMSNorm + output projection +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Optional + +import paddle + +from fastdeploy.model_executor.layers.attention.base_attention_backend import ( + AttentionBackend, +) + +if TYPE_CHECKING: + from fastdeploy.model_executor.forward_meta import ForwardMeta + +logger = logging.getLogger(__name__) + + +# ============================================================================== +# Helper functions (extracted from qwen3_5.py) +# ============================================================================== + + +def l2norm(x: paddle.Tensor, axis: int = -1, eps: float = 1e-6) -> paddle.Tensor: + """L2 normalization, aligns with FLA library.""" + inv_norm = paddle.rsqrt((x * x).sum(axis=axis, keepdim=True) + eps) + return x * inv_norm + + +def fused_gdn_gating( + A_log: paddle.Tensor, + a: paddle.Tensor, + b: paddle.Tensor, + dt_bias: paddle.Tensor, +) -> tuple: + """Compute GDN gating values. + + Args: + A_log: [num_heads] - log of A matrix + a: [num_tokens, num_heads] - alpha values + b: [num_tokens, num_heads] - beta values + dt_bias: [num_heads] - delta-time bias + + Returns: + g: gating values, same shape as a + beta: sigmoid(b), same shape as b + """ + x = a.cast(paddle.float32) + dt_bias.cast(paddle.float32) + softplus_x = paddle.nn.functional.softplus(x) + g = -paddle.exp(A_log.cast(paddle.float32)) * softplus_x + beta = paddle.nn.functional.sigmoid(b.cast(paddle.float32)) + return g, beta + + +def _causal_conv1d_single_seq( + x: paddle.Tensor, + conv_weights: paddle.Tensor, + bias: Optional[paddle.Tensor], + activation: str, + kernel_size: int, +) -> paddle.Tensor: + """Apply causal conv1d to a single sequence (pure Paddle fallback). 
+ + Args: + x: [seq_len, channels] + conv_weights: [channels, kernel_size] + Returns: + [seq_len, channels] + """ + seq_len, channels = x.shape + x = x.transpose([1, 0]).unsqueeze(0) # [1, channels, seq_len] + weight = conv_weights.unsqueeze(1) # [channels, 1, kernel_size] + padding = kernel_size - 1 + x = paddle.nn.functional.conv1d(x, weight, bias, padding=padding, groups=channels) + x = x[:, :, :seq_len] + x = x.squeeze(0).transpose([1, 0]) # [seq_len, channels] + if activation == "silu": + x = paddle.nn.functional.silu(x) + return x + + +def _causal_conv1d_fn_fallback( + x: paddle.Tensor, + conv_weights: paddle.Tensor, + bias: Optional[paddle.Tensor] = None, + activation: str = "silu", + cu_seqlens: Optional[paddle.Tensor] = None, +) -> paddle.Tensor: + """Causal conv1d for packed sequences (pure Paddle fallback). + + Args: + x: [num_tokens, channels] + conv_weights: [channels, kernel_size] + cu_seqlens: [batch_size + 1] cumulative sequence lengths + Returns: + [num_tokens, channels] + """ + kernel_size = conv_weights.shape[-1] + if cu_seqlens is None: + return _causal_conv1d_single_seq(x, conv_weights, bias, activation, kernel_size) + + batch_size = cu_seqlens.shape[0] - 1 + cu_seqlens_np = cu_seqlens.numpy() + output = paddle.zeros_like(x) + + for i in range(batch_size): + start = int(cu_seqlens_np[i]) + end = int(cu_seqlens_np[i + 1]) + if end <= start: + continue + seq_x = x[start:end, :] + seq_out = _causal_conv1d_single_seq(seq_x, conv_weights, bias, activation, kernel_size) + output[start:end, :] = seq_out + + return output + + +def _paddle_chunk_gated_delta_rule( + query: paddle.Tensor, + key: paddle.Tensor, + value: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + chunk_size: int = 64, + initial_state: Optional[paddle.Tensor] = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple: + """Chunked Gated Delta Rule (pure Paddle fallback for prefill). 
+ + Args: + query: [batch, seq_len, num_heads, head_k_dim] + key: [batch, seq_len, num_heads, head_k_dim] + value: [batch, seq_len, num_heads, head_v_dim] + g: [batch, seq_len, num_heads] + beta: [batch, seq_len, num_heads] + Returns: + (output, last_state) + """ + initial_dtype = query.dtype + + if use_qk_l2norm_in_kernel: + query = l2norm(query.cast(paddle.float32), axis=-1) + key = l2norm(key.cast(paddle.float32), axis=-1) + + query = query.transpose([0, 2, 1, 3]).cast(paddle.float32) + key = key.transpose([0, 2, 1, 3]).cast(paddle.float32) + value = value.transpose([0, 2, 1, 3]).cast(paddle.float32) + beta = beta.transpose([0, 2, 1]).cast(paddle.float32) + g = g.transpose([0, 2, 1]).cast(paddle.float32) + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + + pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size + query = paddle.nn.functional.pad(query, [0, 0, 0, pad_size]) + key = paddle.nn.functional.pad(key, [0, 0, 0, pad_size]) + value = paddle.nn.functional.pad(value, [0, 0, 0, pad_size]) + beta = paddle.nn.functional.pad(beta, [0, pad_size]) + g = paddle.nn.functional.pad(g, [0, pad_size]) + total_sequence_length = sequence_length + pad_size + + scale = 1.0 / (query.shape[-1] ** 0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + + query = query.reshape([batch_size, num_heads, -1, chunk_size, k_head_dim]) + key = key.reshape([batch_size, num_heads, -1, chunk_size, k_head_dim]) + value = value.reshape([batch_size, num_heads, -1, chunk_size, v_head_dim]) + k_beta = k_beta.reshape([batch_size, num_heads, -1, chunk_size, k_head_dim]) + v_beta = v_beta.reshape([batch_size, num_heads, -1, chunk_size, v_head_dim]) + g = g.reshape([batch_size, num_heads, -1, chunk_size]) + + g = g.cumsum(axis=-1) + decay_mask = paddle.tril(paddle.exp((g.unsqueeze(-1) - g.unsqueeze(-2)).tril())) + + mask = paddle.triu(paddle.ones([chunk_size, chunk_size], 
dtype="bool"), diagonal=0) + attn = -((k_beta @ key.transpose([0, 1, 2, 4, 3])) * decay_mask) + attn = paddle.where(mask, paddle.zeros_like(attn), attn) + + for i in range(1, chunk_size): + row = attn[:, :, :, i, :i].clone() + sub = attn[:, :, :, :i, :i].clone() + attn[:, :, :, i, :i] = row + (row.unsqueeze(-1) * sub).sum(axis=-2) + + attn = attn + paddle.eye(chunk_size, dtype=attn.dtype) + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + + last_recurrent_state = ( + paddle.zeros([batch_size, num_heads, k_head_dim, v_head_dim]) + if initial_state is None + else initial_state.cast(value.dtype) + ) + + core_attn_out = paddle.zeros_like(value) + mask = paddle.triu(paddle.ones([chunk_size, chunk_size], dtype="bool"), diagonal=1) + + num_chunks = total_sequence_length // chunk_size + for i in range(num_chunks): + q_i = query[:, :, i] + k_i = key[:, :, i] + v_i = value[:, :, i] + + attn_i = q_i @ k_i.transpose([0, 1, 3, 2]) * decay_mask[:, :, i] + attn_i = paddle.where(mask, paddle.zeros_like(attn_i), attn_i) + v_prime = k_cumdecay[:, :, i] @ last_recurrent_state + v_new = v_i - v_prime + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn_i @ v_new + + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose([0, 1, 3, 2]) @ v_new + ) + + if not output_final_state: + last_recurrent_state = None + + core_attn_out = core_attn_out.reshape([batch_size, num_heads, -1, v_head_dim]) + core_attn_out = core_attn_out[:, :, :sequence_length] + core_attn_out = core_attn_out.transpose([0, 2, 1, 3]).cast(initial_dtype) + + return core_attn_out, last_recurrent_state + + +def _paddle_recurrent_gated_delta_rule( + query: paddle.Tensor, + key: paddle.Tensor, + value: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + initial_state: Optional[paddle.Tensor] = None, + output_final_state: bool 
= False, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple: + """Recurrent Gated Delta Rule (pure Paddle fallback for decode). + + Args: + query: [batch, seq_len, num_heads, head_k_dim] + key: [batch, seq_len, num_heads, head_k_dim] + value: [batch, seq_len, num_heads, head_v_dim] + g: [batch, seq_len, num_heads] + beta: [batch, seq_len, num_heads] + Returns: + (output, last_state) + """ + initial_dtype = query.dtype + + if use_qk_l2norm_in_kernel: + query = l2norm(query.cast(paddle.float32), axis=-1) + key = l2norm(key.cast(paddle.float32), axis=-1) + + query = query.transpose([0, 2, 1, 3]).cast(paddle.float32) + key = key.transpose([0, 2, 1, 3]).cast(paddle.float32) + value = value.transpose([0, 2, 1, 3]).cast(paddle.float32) + beta = beta.transpose([0, 2, 1]).cast(paddle.float32) + g = g.transpose([0, 2, 1]).cast(paddle.float32) + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + + scale = 1.0 / (query.shape[-1] ** 0.5) + query = query * scale + + core_attn_out = paddle.zeros([batch_size, num_heads, sequence_length, v_head_dim]) + last_recurrent_state = ( + paddle.zeros([batch_size, num_heads, k_head_dim, v_head_dim]) + if initial_state is None + else initial_state.cast(core_attn_out.dtype) + ) + + for i in range(sequence_length): + q_t = query[:, :, i] + k_t = key[:, :, i] + v_t = value[:, :, i] + g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, i].unsqueeze(-1) + + last_recurrent_state = last_recurrent_state * g_t + kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(axis=-2) + delta = (v_t - kv_mem) * beta_t + last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(axis=-2) + + if not output_final_state: + last_recurrent_state = None + + core_attn_out = core_attn_out.transpose([0, 2, 1, 3]).cast(initial_dtype) + return core_attn_out, last_recurrent_state + + +def 
_fused_recurrent_gated_delta_rule_fallback( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + initial_state: Optional[paddle.Tensor] = None, +) -> tuple: + """Fallback wrapper: converts 3D [num_tokens, H, D] to 4D and dispatches. + + Returns: + (output [num_tokens, H, DV], last_state) + """ + num_tokens = q.shape[0] + q_4d = q.unsqueeze(0) + k_4d = k.unsqueeze(0) + v_4d = v.unsqueeze(0) + g_4d = g.unsqueeze(0) + beta_4d = beta.unsqueeze(0) + + if num_tokens > 1: + out, last_state = _paddle_chunk_gated_delta_rule( + q_4d, + k_4d, + v_4d, + g_4d, + beta_4d, + chunk_size=64, + initial_state=initial_state, + output_final_state=False, + ) + else: + out, last_state = _paddle_recurrent_gated_delta_rule( + q_4d, + k_4d, + v_4d, + g_4d, + beta_4d, + initial_state=initial_state, + output_final_state=False, + ) + return out.squeeze(0), last_state + + +# ============================================================================== +# Kernel Dispatcher (strategy pattern, inspired by SGLang GDNKernelDispatcher) +# ============================================================================== + + +class GDNKernelDispatcher: + """Strategy pattern — routes SSM kernel calls to Triton or Paddle fallback. + + Currently: Triton (FLA) kernels only. + Future: add FlashInfer/CUDA by selecting different kernels at construction time. + """ + + def decode( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + *, + ssm_pool: paddle.Tensor, + slot_ids: paddle.Tensor, + ) -> paddle.Tensor: + """Decode: fused recurrent kernel (pool-indexed, in-place). 
+ + Args: + q,k: [batch, 1, H, DK] + v: [batch, 1, H, DV] + g,beta: [batch, 1, H] + ssm_pool: [pool_size, H, K, V] + slot_ids: [batch] int32 (already offset: PAD→0) + Returns: + [batch, 1, H, DV] + """ + from fastdeploy.model_executor.ops.triton_ops.fla import ( + fused_recurrent_gated_delta_rule_update, + ) + + return fused_recurrent_gated_delta_rule_update( + q=q, + k=k, + v=v, + g=g, + beta=beta, + ssm_pool=ssm_pool, + ssm_indices=slot_ids, + use_qk_l2norm_in_kernel=True, + ) + + def extend( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + *, + ssm_pool: paddle.Tensor, + slot_ids: paddle.Tensor, + cu_seqlens: paddle.Tensor, + ) -> paddle.Tensor: + """Prefill/extend: chunk kernel + state writeback. + + Args: + q,k: [1, total_tokens, H, DK] + v: [1, total_tokens, H, DV] + g,beta: [1, total_tokens, H] + ssm_pool: [pool_size, H, K, V] + slot_ids: [batch] int32 (already offset: PAD→0) + cu_seqlens: [batch+1] int32 + Returns: + [1, total_tokens, H, DV] + """ + from fastdeploy.model_executor.ops.triton_ops.fla import chunk_gated_delta_rule + + # Clone initial state from pool — chunk kernel updates in-place + initial_state = ssm_pool[slot_ids].clone() # [batch, H, K, V] + + o, _h = chunk_gated_delta_rule( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=None, + initial_state=initial_state, + initial_state_indices=paddle.arange(slot_ids.shape[0], dtype=paddle.int32), + cu_seqlens=cu_seqlens, + use_qk_l2norm_in_kernel=True, + ) + + # Write back final states (updated in-place by the kernel) + batch_size = slot_ids.shape[0] + for i in range(batch_size): + sid = int(slot_ids[i]) + if sid > 0: # skip padding sentinel (slot 0) + ssm_pool[sid] = initial_state[i] + + return o + + def decode_fallback( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + ) -> paddle.Tensor: + """Fallback decode: pure Paddle recurrent (no pool). 
+ + Args: q,k,v,g,beta in 3D [num_tokens, H, D] format + Returns: [num_tokens, H, DV] + """ + return _fused_recurrent_gated_delta_rule_fallback(q, k, v, g, beta)[0] + + def extend_fallback( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + ) -> paddle.Tensor: + """Fallback extend: pure Paddle chunk (no pool). + + Args: q,k,v,g,beta in 3D [num_tokens, H, D] format + Returns: [num_tokens, H, DV] + """ + return _fused_recurrent_gated_delta_rule_fallback(q, k, v, g, beta)[0] + + +# ============================================================================== +# GDN Attention Backend +# ============================================================================== + + +class GDNAttentionBackend(AttentionBackend): + """GDN (Gated Delta Network) linear attention backend. + + Inherits AttentionBackend for formal consistency with FD/vLLM/SGLang. + The model layer calls forward_meta.gdn_attn_backend.forward() directly + (not through the standard Attention trampoline). + + Internal flow: + 1. Causal Conv1d (decode: triton update / prefill: triton varlen / fallback) + 2. Split Q,K,V + fused_gdn_gating + 3. GVA repeat (if num_v_heads > num_k_heads) + 4. SSM kernel via GDNKernelDispatcher (decode / extend / fallback) + """ + + def __init__(self): + self.kernel_dispatcher = GDNKernelDispatcher() + + def init_attention_metadata(self, forward_meta: ForwardMeta): + """GDN does not need standard attention metadata.""" + pass + + def forward( + self, + mixed_qkv: paddle.Tensor, + a: paddle.Tensor, + b: paddle.Tensor, + layer, + forward_meta: ForwardMeta, + ) -> paddle.Tensor: + """Unified forward entry — model layer calls this single method. 
+
+        Args:
+            mixed_qkv: [num_tokens, conv_dim] — projected QKV (after in_proj_qkv)
+            a: [num_tokens, num_v_heads] — alpha gating input
+            b: [num_tokens, num_v_heads] — beta gating input
+            layer: Qwen3_5GatedDeltaNet instance (provides conv1d_weight, A_log, dt_bias, dims)
+            forward_meta: ForwardMeta (contains pool, slot_ids, forward_mode, etc.)
+
+        Returns:
+            [num_tokens, num_v_heads_local, head_v_dim]
+        """
+        from fastdeploy.cache_manager.gdn_state_pool import GDNStatePool
+
+        num_tokens = mixed_qkv.shape[0]
+        is_decode = forward_meta.forward_mode.is_decode()
+
+        # Get pool views for this layer
+        gdn_pool = forward_meta.gdn_state_pool
+        raw_slot_ids = forward_meta.gdn_slot_ids
+
+        conv_pool = gdn_pool.get_layer_conv_pool(layer.gdn_layer_idx) if gdn_pool is not None else None
+        ssm_pool = gdn_pool.get_layer_ssm_pool(layer.gdn_layer_idx) if gdn_pool is not None else None
+
+        # Offset slot_ids: PAD_SLOT_ID (-1) → slot 0
+        slot_ids = GDNStatePool.offset_slot_ids(raw_slot_ids) if raw_slot_ids is not None else None
+
+        has_pool = conv_pool is not None and slot_ids is not None
+
+        # ============================================================
+        # 1. 
Causal Conv1d + # ============================================================ + conv_weight_local = layer.conv1d_weight[: layer.conv_dim] + + if is_decode and has_pool: + from fastdeploy.model_executor.ops.triton_ops.causal_conv1d import ( + causal_conv1d_update as triton_conv1d_update, + ) + + mixed_qkv = triton_conv1d_update( + x=mixed_qkv, + conv_state=conv_pool, + weight=conv_weight_local, + bias=None, + activation="silu", + conv_state_indices=slot_ids, + ) + elif not is_decode and has_pool: + from fastdeploy.model_executor.ops.triton_ops.causal_conv1d import ( + causal_conv1d_fn as triton_conv1d_fn, + ) + + cu_seqlens = forward_meta.cu_seqlens_q + mixed_qkv = triton_conv1d_fn( + x=mixed_qkv.T, # [dim, total_tokens] + weight=conv_weight_local, + bias=None, + conv_states=conv_pool, + query_start_loc=cu_seqlens, + seq_lens_cpu=forward_meta.gdn_seq_lens_cpu, + cache_indices=slot_ids, + has_initial_state=forward_meta.gdn_has_initial_state, + activation="silu", + ).T # [total_tokens, dim] + else: + cu_seqlens = forward_meta.cu_seqlens_q + if cu_seqlens is None and forward_meta.seq_lens_this_time is not None: + seq_lens = forward_meta.seq_lens_this_time.numpy() + cu_seqlens_list = [0] + for sl in seq_lens: + cu_seqlens_list.append(cu_seqlens_list[-1] + sl) + cu_seqlens = paddle.to_tensor(cu_seqlens_list, dtype="int32") + mixed_qkv = _causal_conv1d_fn_fallback( + mixed_qkv, + conv_weight_local, + bias=None, + activation="silu", + cu_seqlens=cu_seqlens, + ) + + # ============================================================ + # 2. 
Split Q, K, V + # ============================================================ + key_dim_local = layer.num_k_heads_local * layer.head_k_dim + value_dim_local = layer.num_v_heads_local * layer.head_v_dim + + q, k, v = paddle.split( + mixed_qkv, + [ + key_dim_local, + key_dim_local, + value_dim_local, + ], + axis=-1, + ) + + q = q.reshape([num_tokens, layer.num_k_heads_local, layer.head_k_dim]) + k = k.reshape([num_tokens, layer.num_k_heads_local, layer.head_k_dim]) + v = v.reshape([num_tokens, layer.num_v_heads_local, layer.head_v_dim]) + + # ============================================================ + # 3. GDN Gating + # ============================================================ + A_log_local = layer.A_log[: layer.num_v_heads_local] + dt_bias_local = layer.dt_bias[: layer.num_v_heads_local] + a_local = a[:, : layer.num_v_heads_local] + b_local = b[:, : layer.num_v_heads_local] + + g, beta = fused_gdn_gating(A_log_local, a_local, b_local, dt_bias_local) + + # ============================================================ + # 4. GVA repeat (if num_v_heads > num_k_heads) + # ============================================================ + if layer.num_v_heads_per_k_head > 1: + q = ( + q.unsqueeze(2) + .expand([num_tokens, layer.num_k_heads_local, layer.num_v_heads_per_k_head, layer.head_k_dim]) + .reshape([num_tokens, layer.num_v_heads_local, layer.head_k_dim]) + ) + k = ( + k.unsqueeze(2) + .expand([num_tokens, layer.num_k_heads_local, layer.num_v_heads_per_k_head, layer.head_k_dim]) + .reshape([num_tokens, layer.num_v_heads_local, layer.head_k_dim]) + ) + + # ============================================================ + # 5. 
Core SSM Attention (via dispatcher) + # ============================================================ + if is_decode and has_pool: + # Decode: fused recurrent (Triton) + q_4d = q.unsqueeze(1) # [batch, 1, H, K] + k_4d = k.unsqueeze(1) + v_4d = v.unsqueeze(1) + g_4d = g.unsqueeze(1) # [batch, 1, H] + beta_4d = beta.unsqueeze(1) + + o = self.kernel_dispatcher.decode( + q_4d, + k_4d, + v_4d, + g_4d, + beta_4d, + ssm_pool=ssm_pool, + slot_ids=slot_ids, + ) + core_attn_out = o.squeeze(1) # [batch, H, V] + + elif not is_decode and has_pool: + # Prefill/extend: chunk (Triton) + cu_seqlens = forward_meta.cu_seqlens_q + q_4d = q.unsqueeze(0) # [1, total_tokens, H, K] + k_4d = k.unsqueeze(0) + v_4d = v.unsqueeze(0) + g_4d = g.unsqueeze(0) # [1, total_tokens, H] + beta_4d = beta.unsqueeze(0) + + o = self.kernel_dispatcher.extend( + q_4d, + k_4d, + v_4d, + g_4d, + beta_4d, + ssm_pool=ssm_pool, + slot_ids=slot_ids, + cu_seqlens=cu_seqlens, + ) + core_attn_out = o.squeeze(0) # [total_tokens, H, V] + + else: + # Fallback: pure Paddle (no pool) + core_attn_out = self.kernel_dispatcher.extend_fallback(q, k, v, g, beta) + + return core_attn_out diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 7682caac64c..268c675a070 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1367,6 +1367,9 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): # Derive gdn_seq_lens_cpu from seq_lens_this_time if self.forward_meta.seq_lens_this_time is not None: self.forward_meta.gdn_seq_lens_cpu = self.forward_meta.seq_lens_this_time.numpy().tolist() + # GDN attention backend + if getattr(self, "gdn_attn_backend", None) is not None: + self.forward_meta.gdn_attn_backend = self.gdn_attn_backend def initialize_kv_cache(self, profile: bool = False) -> None: """ @@ -1576,6 +1579,14 @@ def _initialize_gdn_state_pool(self) -> None: f"num_v_heads_local={num_v_heads_local}" ) + # Initialize GDN attention 
backend (kernel dispatcher + unified forward) + from fastdeploy.model_executor.layers.attention.gdn_backend import ( + GDNAttentionBackend, + ) + + self.gdn_attn_backend = GDNAttentionBackend() + logger.info("GDN attention backend initialized (GDNKernelDispatcher + GDNAttentionBackend)") + def _initialize_attn_backend(self) -> None: """ Initialize attention backends From 43405fa2603e2c70190c0d6f876e0b49371a18b9 Mon Sep 17 00:00:00 2001 From: liukebin Date: Tue, 31 Mar 2026 13:17:45 +0800 Subject: [PATCH 3/6] [Feature] support gdn backend for qwen3.5 --- .../layers/attention/__init__.py | 2 + .../layers/attention/gdn_attention.py | 75 +++++++++++++++++++ 2 files changed, 77 insertions(+) create mode 100644 fastdeploy/model_executor/layers/attention/gdn_attention.py diff --git a/fastdeploy/model_executor/layers/attention/__init__.py b/fastdeploy/model_executor/layers/attention/__init__.py index 7efc3259fbc..c75f7e4de18 100644 --- a/fastdeploy/model_executor/layers/attention/__init__.py +++ b/fastdeploy/model_executor/layers/attention/__init__.py @@ -20,6 +20,7 @@ from .dsa_attention_backend import DSAAttentionBackend from .flash_attn_backend import FlashAttentionBackend from .flash_mask_attn_backend import FlashMaskAttentionBackend +from .gdn_attention import GDNAttention from .mla_attention_backend import MLAAttentionBackend from .moba_attention_backend import PlasAttentionBackend from .native_paddle_backend import PaddleNativeAttnBackend @@ -34,6 +35,7 @@ "FlashAttentionBackend", "BlockAttentionBackend", "Attention", + "GDNAttention", "PlasAttentionBackend", "FlashMaskAttentionBackend", ] diff --git a/fastdeploy/model_executor/layers/attention/gdn_attention.py b/fastdeploy/model_executor/layers/attention/gdn_attention.py new file mode 100644 index 00000000000..fc9f67a10aa --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/gdn_attention.py @@ -0,0 +1,75 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GDN linear attention trampoline layer. + +Analogous to ``Attention`` for standard softmax attention — instantiated in the +model layer's ``__init__`` and called in ``forward``, delegating to the backend +on ``forward_meta``. + +Usage:: + + class Qwen3_5GatedDeltaNet(nn.Layer): + def __init__(self, ...): + ... + self.gdn_attn = GDNAttention() + + def forward(self, forward_meta, hidden_states): + ... + core_attn_out = self.gdn_attn(mixed_qkv, a, b, self, forward_meta) +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import paddle +from paddle import nn + +if TYPE_CHECKING: + from fastdeploy.model_executor.forward_meta import ForwardMeta + + +class GDNAttention(nn.Layer): + """GDN (Gated Delta Network) linear attention trampoline. + + Mirrors the role of :class:`Attention` for softmax attention: + the model layer holds ``self.gdn_attn = GDNAttention()`` and calls + ``self.gdn_attn(mixed_qkv, a, b, self, forward_meta)`` in its forward. + + Internally delegates to ``forward_meta.gdn_attn_backend.forward()``. + """ + + def forward( + self, + mixed_qkv: paddle.Tensor, + a: paddle.Tensor, + b: paddle.Tensor, + layer: nn.Layer, + forward_meta: ForwardMeta, + ) -> paddle.Tensor: + """Forward pass — delegates to the GDN attention backend. + + Args: + mixed_qkv: Projected QKV tensor ``[num_tokens, conv_dim]``. + a: Gating input a ``[num_tokens, num_v_heads]``. 
+ b: Gating input b ``[num_tokens, num_v_heads]``. + layer: The parent ``Qwen3_5GatedDeltaNet`` instance (provides + ``conv_weight``, ``A_log_local``, ``dt_bias_local``, etc.). + forward_meta: Per-step forward metadata. + + Returns: + Attention output ``[num_tokens, num_v_heads_local, head_v_dim]``. + """ + return forward_meta.gdn_attn_backend.forward(mixed_qkv, a, b, layer, forward_meta) From 44d907484f9eb6ee2ea165860145333d075e2b02 Mon Sep 17 00:00:00 2001 From: liukebin Date: Tue, 31 Mar 2026 13:35:49 +0800 Subject: [PATCH 4/6] [Feature] support gdn backend for qwen3.5 --- .../layers/attention/gdn_backend.py | 386 ++---------------- 1 file changed, 26 insertions(+), 360 deletions(-) diff --git a/fastdeploy/model_executor/layers/attention/gdn_backend.py b/fastdeploy/model_executor/layers/attention/gdn_backend.py index 5384008f66e..cf61a78bb13 100644 --- a/fastdeploy/model_executor/layers/attention/gdn_backend.py +++ b/fastdeploy/model_executor/layers/attention/gdn_backend.py @@ -15,30 +15,30 @@ GDN (Gated Delta Network) Attention Backend for Qwen3.5 linear attention. Architecture (inspired by SGLang's gdn_backend.py): - - GDNKernelDispatcher: strategy pattern routing to Triton / Paddle fallback kernels + - GDNKernelDispatcher: strategy pattern routing to Triton (FLA) kernels - GDNAttentionBackend(AttentionBackend): unified forward entry for model layer Call chain: Model Layer: Qwen3_5GatedDeltaNet.forward() |-- projections (qkv, z, b, a) - |-- forward_meta.gdn_attn_backend.forward(mixed_qkv, a, b, layer, forward_meta) - | - GDNAttentionBackend.forward() - |-- causal_conv1d (decode / prefill / fallback) - |-- split Q,K,V + fused_gdn_gating - |-- GVA repeat (if needed) - |-- kernel_dispatcher.decode/extend/fallback(...) 
- | - GDNKernelDispatcher - |-- Triton FLA kernels (decode: fused_recurrent, extend: chunk) - |-- Paddle fallback (paddle_recurrent / paddle_chunk) + |-- self.gdn_attn(mixed_qkv, a, b, self, forward_meta) [GDNAttention trampoline] + |-- forward_meta.gdn_attn_backend.forward(mixed_qkv, a, b, layer, forward_meta) + | + GDNAttentionBackend.forward() + |-- causal_conv1d (decode: triton update / prefill: triton varlen) + |-- split Q,K,V + fused_gdn_gating + |-- GVA repeat (if needed) + |-- kernel_dispatcher.decode/extend(...) + | + GDNKernelDispatcher + |-- Triton FLA kernels (decode: fused_recurrent, extend: chunk) |-- gated RMSNorm + output projection """ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import paddle @@ -88,297 +88,13 @@ def fused_gdn_gating( return g, beta -def _causal_conv1d_single_seq( - x: paddle.Tensor, - conv_weights: paddle.Tensor, - bias: Optional[paddle.Tensor], - activation: str, - kernel_size: int, -) -> paddle.Tensor: - """Apply causal conv1d to a single sequence (pure Paddle fallback). - - Args: - x: [seq_len, channels] - conv_weights: [channels, kernel_size] - Returns: - [seq_len, channels] - """ - seq_len, channels = x.shape - x = x.transpose([1, 0]).unsqueeze(0) # [1, channels, seq_len] - weight = conv_weights.unsqueeze(1) # [channels, 1, kernel_size] - padding = kernel_size - 1 - x = paddle.nn.functional.conv1d(x, weight, bias, padding=padding, groups=channels) - x = x[:, :, :seq_len] - x = x.squeeze(0).transpose([1, 0]) # [seq_len, channels] - if activation == "silu": - x = paddle.nn.functional.silu(x) - return x - - -def _causal_conv1d_fn_fallback( - x: paddle.Tensor, - conv_weights: paddle.Tensor, - bias: Optional[paddle.Tensor] = None, - activation: str = "silu", - cu_seqlens: Optional[paddle.Tensor] = None, -) -> paddle.Tensor: - """Causal conv1d for packed sequences (pure Paddle fallback). 
- - Args: - x: [num_tokens, channels] - conv_weights: [channels, kernel_size] - cu_seqlens: [batch_size + 1] cumulative sequence lengths - Returns: - [num_tokens, channels] - """ - kernel_size = conv_weights.shape[-1] - if cu_seqlens is None: - return _causal_conv1d_single_seq(x, conv_weights, bias, activation, kernel_size) - - batch_size = cu_seqlens.shape[0] - 1 - cu_seqlens_np = cu_seqlens.numpy() - output = paddle.zeros_like(x) - - for i in range(batch_size): - start = int(cu_seqlens_np[i]) - end = int(cu_seqlens_np[i + 1]) - if end <= start: - continue - seq_x = x[start:end, :] - seq_out = _causal_conv1d_single_seq(seq_x, conv_weights, bias, activation, kernel_size) - output[start:end, :] = seq_out - - return output - - -def _paddle_chunk_gated_delta_rule( - query: paddle.Tensor, - key: paddle.Tensor, - value: paddle.Tensor, - g: paddle.Tensor, - beta: paddle.Tensor, - chunk_size: int = 64, - initial_state: Optional[paddle.Tensor] = None, - output_final_state: bool = False, - use_qk_l2norm_in_kernel: bool = False, -) -> tuple: - """Chunked Gated Delta Rule (pure Paddle fallback for prefill). 
- - Args: - query: [batch, seq_len, num_heads, head_k_dim] - key: [batch, seq_len, num_heads, head_k_dim] - value: [batch, seq_len, num_heads, head_v_dim] - g: [batch, seq_len, num_heads] - beta: [batch, seq_len, num_heads] - Returns: - (output, last_state) - """ - initial_dtype = query.dtype - - if use_qk_l2norm_in_kernel: - query = l2norm(query.cast(paddle.float32), axis=-1) - key = l2norm(key.cast(paddle.float32), axis=-1) - - query = query.transpose([0, 2, 1, 3]).cast(paddle.float32) - key = key.transpose([0, 2, 1, 3]).cast(paddle.float32) - value = value.transpose([0, 2, 1, 3]).cast(paddle.float32) - beta = beta.transpose([0, 2, 1]).cast(paddle.float32) - g = g.transpose([0, 2, 1]).cast(paddle.float32) - - batch_size, num_heads, sequence_length, k_head_dim = key.shape - v_head_dim = value.shape[-1] - - pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size - query = paddle.nn.functional.pad(query, [0, 0, 0, pad_size]) - key = paddle.nn.functional.pad(key, [0, 0, 0, pad_size]) - value = paddle.nn.functional.pad(value, [0, 0, 0, pad_size]) - beta = paddle.nn.functional.pad(beta, [0, pad_size]) - g = paddle.nn.functional.pad(g, [0, pad_size]) - total_sequence_length = sequence_length + pad_size - - scale = 1.0 / (query.shape[-1] ** 0.5) - query = query * scale - - v_beta = value * beta.unsqueeze(-1) - k_beta = key * beta.unsqueeze(-1) - - query = query.reshape([batch_size, num_heads, -1, chunk_size, k_head_dim]) - key = key.reshape([batch_size, num_heads, -1, chunk_size, k_head_dim]) - value = value.reshape([batch_size, num_heads, -1, chunk_size, v_head_dim]) - k_beta = k_beta.reshape([batch_size, num_heads, -1, chunk_size, k_head_dim]) - v_beta = v_beta.reshape([batch_size, num_heads, -1, chunk_size, v_head_dim]) - g = g.reshape([batch_size, num_heads, -1, chunk_size]) - - g = g.cumsum(axis=-1) - decay_mask = paddle.tril(paddle.exp((g.unsqueeze(-1) - g.unsqueeze(-2)).tril())) - - mask = paddle.triu(paddle.ones([chunk_size, chunk_size], 
dtype="bool"), diagonal=0) - attn = -((k_beta @ key.transpose([0, 1, 2, 4, 3])) * decay_mask) - attn = paddle.where(mask, paddle.zeros_like(attn), attn) - - for i in range(1, chunk_size): - row = attn[:, :, :, i, :i].clone() - sub = attn[:, :, :, :i, :i].clone() - attn[:, :, :, i, :i] = row + (row.unsqueeze(-1) * sub).sum(axis=-2) - - attn = attn + paddle.eye(chunk_size, dtype=attn.dtype) - value = attn @ v_beta - k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) - - last_recurrent_state = ( - paddle.zeros([batch_size, num_heads, k_head_dim, v_head_dim]) - if initial_state is None - else initial_state.cast(value.dtype) - ) - - core_attn_out = paddle.zeros_like(value) - mask = paddle.triu(paddle.ones([chunk_size, chunk_size], dtype="bool"), diagonal=1) - - num_chunks = total_sequence_length // chunk_size - for i in range(num_chunks): - q_i = query[:, :, i] - k_i = key[:, :, i] - v_i = value[:, :, i] - - attn_i = q_i @ k_i.transpose([0, 1, 3, 2]) * decay_mask[:, :, i] - attn_i = paddle.where(mask, paddle.zeros_like(attn_i), attn_i) - v_prime = k_cumdecay[:, :, i] @ last_recurrent_state - v_new = v_i - v_prime - attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - core_attn_out[:, :, i] = attn_inter + attn_i @ v_new - - last_recurrent_state = ( - last_recurrent_state * g[:, :, i, -1, None, None].exp() - + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose([0, 1, 3, 2]) @ v_new - ) - - if not output_final_state: - last_recurrent_state = None - - core_attn_out = core_attn_out.reshape([batch_size, num_heads, -1, v_head_dim]) - core_attn_out = core_attn_out[:, :, :sequence_length] - core_attn_out = core_attn_out.transpose([0, 2, 1, 3]).cast(initial_dtype) - - return core_attn_out, last_recurrent_state - - -def _paddle_recurrent_gated_delta_rule( - query: paddle.Tensor, - key: paddle.Tensor, - value: paddle.Tensor, - g: paddle.Tensor, - beta: paddle.Tensor, - initial_state: Optional[paddle.Tensor] = None, - output_final_state: bool 
= False, - use_qk_l2norm_in_kernel: bool = False, -) -> tuple: - """Recurrent Gated Delta Rule (pure Paddle fallback for decode). - - Args: - query: [batch, seq_len, num_heads, head_k_dim] - key: [batch, seq_len, num_heads, head_k_dim] - value: [batch, seq_len, num_heads, head_v_dim] - g: [batch, seq_len, num_heads] - beta: [batch, seq_len, num_heads] - Returns: - (output, last_state) - """ - initial_dtype = query.dtype - - if use_qk_l2norm_in_kernel: - query = l2norm(query.cast(paddle.float32), axis=-1) - key = l2norm(key.cast(paddle.float32), axis=-1) - - query = query.transpose([0, 2, 1, 3]).cast(paddle.float32) - key = key.transpose([0, 2, 1, 3]).cast(paddle.float32) - value = value.transpose([0, 2, 1, 3]).cast(paddle.float32) - beta = beta.transpose([0, 2, 1]).cast(paddle.float32) - g = g.transpose([0, 2, 1]).cast(paddle.float32) - - batch_size, num_heads, sequence_length, k_head_dim = key.shape - v_head_dim = value.shape[-1] - - scale = 1.0 / (query.shape[-1] ** 0.5) - query = query * scale - - core_attn_out = paddle.zeros([batch_size, num_heads, sequence_length, v_head_dim]) - last_recurrent_state = ( - paddle.zeros([batch_size, num_heads, k_head_dim, v_head_dim]) - if initial_state is None - else initial_state.cast(core_attn_out.dtype) - ) - - for i in range(sequence_length): - q_t = query[:, :, i] - k_t = key[:, :, i] - v_t = value[:, :, i] - g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) - beta_t = beta[:, :, i].unsqueeze(-1) - - last_recurrent_state = last_recurrent_state * g_t - kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(axis=-2) - delta = (v_t - kv_mem) * beta_t - last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) - core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(axis=-2) - - if not output_final_state: - last_recurrent_state = None - - core_attn_out = core_attn_out.transpose([0, 2, 1, 3]).cast(initial_dtype) - return core_attn_out, last_recurrent_state - - -def 
_fused_recurrent_gated_delta_rule_fallback( - q: paddle.Tensor, - k: paddle.Tensor, - v: paddle.Tensor, - g: paddle.Tensor, - beta: paddle.Tensor, - initial_state: Optional[paddle.Tensor] = None, -) -> tuple: - """Fallback wrapper: converts 3D [num_tokens, H, D] to 4D and dispatches. - - Returns: - (output [num_tokens, H, DV], last_state) - """ - num_tokens = q.shape[0] - q_4d = q.unsqueeze(0) - k_4d = k.unsqueeze(0) - v_4d = v.unsqueeze(0) - g_4d = g.unsqueeze(0) - beta_4d = beta.unsqueeze(0) - - if num_tokens > 1: - out, last_state = _paddle_chunk_gated_delta_rule( - q_4d, - k_4d, - v_4d, - g_4d, - beta_4d, - chunk_size=64, - initial_state=initial_state, - output_final_state=False, - ) - else: - out, last_state = _paddle_recurrent_gated_delta_rule( - q_4d, - k_4d, - v_4d, - g_4d, - beta_4d, - initial_state=initial_state, - output_final_state=False, - ) - return out.squeeze(0), last_state - - # ============================================================================== # Kernel Dispatcher (strategy pattern, inspired by SGLang GDNKernelDispatcher) # ============================================================================== class GDNKernelDispatcher: - """Strategy pattern — routes SSM kernel calls to Triton or Paddle fallback. + """Strategy pattern — routes SSM kernel calls to Triton (FLA) kernels. Currently: Triton (FLA) kernels only. Future: add FlashInfer/CUDA by selecting different kernels at construction time. @@ -472,36 +188,6 @@ def extend( return o - def decode_fallback( - self, - q: paddle.Tensor, - k: paddle.Tensor, - v: paddle.Tensor, - g: paddle.Tensor, - beta: paddle.Tensor, - ) -> paddle.Tensor: - """Fallback decode: pure Paddle recurrent (no pool). 
- - Args: q,k,v,g,beta in 3D [num_tokens, H, D] format - Returns: [num_tokens, H, DV] - """ - return _fused_recurrent_gated_delta_rule_fallback(q, k, v, g, beta)[0] - - def extend_fallback( - self, - q: paddle.Tensor, - k: paddle.Tensor, - v: paddle.Tensor, - g: paddle.Tensor, - beta: paddle.Tensor, - ) -> paddle.Tensor: - """Fallback extend: pure Paddle chunk (no pool). - - Args: q,k,v,g,beta in 3D [num_tokens, H, D] format - Returns: [num_tokens, H, DV] - """ - return _fused_recurrent_gated_delta_rule_fallback(q, k, v, g, beta)[0] - # ============================================================================== # GDN Attention Backend @@ -512,14 +198,15 @@ class GDNAttentionBackend(AttentionBackend): """GDN (Gated Delta Network) linear attention backend. Inherits AttentionBackend for formal consistency with FD/vLLM/SGLang. - The model layer calls forward_meta.gdn_attn_backend.forward() directly - (not through the standard Attention trampoline). + The model layer calls through GDNAttention trampoline: + self.gdn_attn(mixed_qkv, a, b, self, forward_meta) + → forward_meta.gdn_attn_backend.forward(...) Internal flow: - 1. Causal Conv1d (decode: triton update / prefill: triton varlen / fallback) + 1. Causal Conv1d (decode: triton update / prefill: triton varlen) 2. Split Q,K,V + fused_gdn_gating 3. GVA repeat (if num_v_heads > num_k_heads) - 4. SSM kernel via GDNKernelDispatcher (decode / extend / fallback) + 4. 
SSM kernel via GDNKernelDispatcher (decode / extend) """ def __init__(self): @@ -558,20 +245,18 @@ def forward( gdn_pool = forward_meta.gdn_state_pool raw_slot_ids = forward_meta.gdn_slot_ids - conv_pool = gdn_pool.get_layer_conv_pool(layer.gdn_layer_idx) if gdn_pool is not None else None - ssm_pool = gdn_pool.get_layer_ssm_pool(layer.gdn_layer_idx) if gdn_pool is not None else None + conv_pool = gdn_pool.get_layer_conv_pool(layer.gdn_layer_idx) + ssm_pool = gdn_pool.get_layer_ssm_pool(layer.gdn_layer_idx) # Offset slot_ids: PAD_SLOT_ID (-1) → slot 0 - slot_ids = GDNStatePool.offset_slot_ids(raw_slot_ids) if raw_slot_ids is not None else None - - has_pool = conv_pool is not None and slot_ids is not None + slot_ids = GDNStatePool.offset_slot_ids(raw_slot_ids) # ============================================================ # 1. Causal Conv1d # ============================================================ conv_weight_local = layer.conv1d_weight[: layer.conv_dim] - if is_decode and has_pool: + if is_decode: from fastdeploy.model_executor.ops.triton_ops.causal_conv1d import ( causal_conv1d_update as triton_conv1d_update, ) @@ -584,7 +269,7 @@ def forward( activation="silu", conv_state_indices=slot_ids, ) - elif not is_decode and has_pool: + else: from fastdeploy.model_executor.ops.triton_ops.causal_conv1d import ( causal_conv1d_fn as triton_conv1d_fn, ) @@ -601,21 +286,6 @@ def forward( has_initial_state=forward_meta.gdn_has_initial_state, activation="silu", ).T # [total_tokens, dim] - else: - cu_seqlens = forward_meta.cu_seqlens_q - if cu_seqlens is None and forward_meta.seq_lens_this_time is not None: - seq_lens = forward_meta.seq_lens_this_time.numpy() - cu_seqlens_list = [0] - for sl in seq_lens: - cu_seqlens_list.append(cu_seqlens_list[-1] + sl) - cu_seqlens = paddle.to_tensor(cu_seqlens_list, dtype="int32") - mixed_qkv = _causal_conv1d_fn_fallback( - mixed_qkv, - conv_weight_local, - bias=None, - activation="silu", - cu_seqlens=cu_seqlens, - ) # 
============================================================ # 2. Split Q, K, V @@ -665,7 +335,7 @@ def forward( # ============================================================ # 5. Core SSM Attention (via dispatcher) # ============================================================ - if is_decode and has_pool: + if is_decode: # Decode: fused recurrent (Triton) q_4d = q.unsqueeze(1) # [batch, 1, H, K] k_4d = k.unsqueeze(1) @@ -684,7 +354,7 @@ def forward( ) core_attn_out = o.squeeze(1) # [batch, H, V] - elif not is_decode and has_pool: + else: # Prefill/extend: chunk (Triton) cu_seqlens = forward_meta.cu_seqlens_q q_4d = q.unsqueeze(0) # [1, total_tokens, H, K] @@ -705,8 +375,4 @@ def forward( ) core_attn_out = o.squeeze(0) # [total_tokens, H, V] - else: - # Fallback: pure Paddle (no pool) - core_attn_out = self.kernel_dispatcher.extend_fallback(q, k, v, g, beta) - return core_attn_out From 706eaa1d160688ee3892137dbdb77c24f54d7aa1 Mon Sep 17 00:00:00 2001 From: liukebin Date: Tue, 31 Mar 2026 21:29:59 +0800 Subject: [PATCH 5/6] [Feature] update gdn backend for qwen3.5 --- .../layers/attention/gdn_backend.py | 75 +-- fastdeploy/worker/gpu_model_runner.py | 25 +- tests/model_executor/test_gdn_backend.py | 546 ++++++++++++++++++ 3 files changed, 583 insertions(+), 63 deletions(-) create mode 100644 tests/model_executor/test_gdn_backend.py diff --git a/fastdeploy/model_executor/layers/attention/gdn_backend.py b/fastdeploy/model_executor/layers/attention/gdn_backend.py index cf61a78bb13..2edf1653c71 100644 --- a/fastdeploy/model_executor/layers/attention/gdn_backend.py +++ b/fastdeploy/model_executor/layers/attention/gdn_backend.py @@ -42,52 +42,25 @@ import paddle +from fastdeploy.cache_manager.gdn_state_pool import GDNStatePool from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, ) +from fastdeploy.model_executor.ops.triton_ops.causal_conv1d import ( + causal_conv1d_fn, + causal_conv1d_update, +) +from 
fastdeploy.model_executor.ops.triton_ops.fla import ( + chunk_gated_delta_rule, + fused_gdn_gating, + fused_recurrent_gated_delta_rule_update, +) if TYPE_CHECKING: from fastdeploy.model_executor.forward_meta import ForwardMeta logger = logging.getLogger(__name__) - -# ============================================================================== -# Helper functions (extracted from qwen3_5.py) -# ============================================================================== - - -def l2norm(x: paddle.Tensor, axis: int = -1, eps: float = 1e-6) -> paddle.Tensor: - """L2 normalization, aligns with FLA library.""" - inv_norm = paddle.rsqrt((x * x).sum(axis=axis, keepdim=True) + eps) - return x * inv_norm - - -def fused_gdn_gating( - A_log: paddle.Tensor, - a: paddle.Tensor, - b: paddle.Tensor, - dt_bias: paddle.Tensor, -) -> tuple: - """Compute GDN gating values. - - Args: - A_log: [num_heads] - log of A matrix - a: [num_tokens, num_heads] - alpha values - b: [num_tokens, num_heads] - beta values - dt_bias: [num_heads] - delta-time bias - - Returns: - g: gating values, same shape as a - beta: sigmoid(b), same shape as b - """ - x = a.cast(paddle.float32) + dt_bias.cast(paddle.float32) - softplus_x = paddle.nn.functional.softplus(x) - g = -paddle.exp(A_log.cast(paddle.float32)) * softplus_x - beta = paddle.nn.functional.sigmoid(b.cast(paddle.float32)) - return g, beta - - # ============================================================================== # Kernel Dispatcher (strategy pattern, inspired by SGLang GDNKernelDispatcher) # ============================================================================== @@ -122,10 +95,6 @@ def decode( Returns: [batch, 1, H, DV] """ - from fastdeploy.model_executor.ops.triton_ops.fla import ( - fused_recurrent_gated_delta_rule_update, - ) - return fused_recurrent_gated_delta_rule_update( q=q, k=k, @@ -161,8 +130,6 @@ def extend( Returns: [1, total_tokens, H, DV] """ - from fastdeploy.model_executor.ops.triton_ops.fla import 
chunk_gated_delta_rule - # Clone initial state from pool — chunk kernel updates in-place initial_state = ssm_pool[slot_ids].clone() # [batch, H, K, V] @@ -236,20 +203,19 @@ def forward( Returns: [num_tokens, num_v_heads_local, head_v_dim] """ - from fastdeploy.cache_manager.gdn_state_pool import GDNStatePool - num_tokens = mixed_qkv.shape[0] is_decode = forward_meta.forward_mode.is_decode() # Get pool views for this layer gdn_pool = forward_meta.gdn_state_pool - raw_slot_ids = forward_meta.gdn_slot_ids + raw_slot_ids = forward_meta.gdn_slot_ids # [active_bs] — already truncated conv_pool = gdn_pool.get_layer_conv_pool(layer.gdn_layer_idx) ssm_pool = gdn_pool.get_layer_ssm_pool(layer.gdn_layer_idx) # Offset slot_ids: PAD_SLOT_ID (-1) → slot 0 slot_ids = GDNStatePool.offset_slot_ids(raw_slot_ids) + batch_size = slot_ids.shape[0] # ============================================================ # 1. Causal Conv1d @@ -257,11 +223,7 @@ def forward( conv_weight_local = layer.conv1d_weight[: layer.conv_dim] if is_decode: - from fastdeploy.model_executor.ops.triton_ops.causal_conv1d import ( - causal_conv1d_update as triton_conv1d_update, - ) - - mixed_qkv = triton_conv1d_update( + mixed_qkv = causal_conv1d_update( x=mixed_qkv, conv_state=conv_pool, weight=conv_weight_local, @@ -270,12 +232,9 @@ def forward( conv_state_indices=slot_ids, ) else: - from fastdeploy.model_executor.ops.triton_ops.causal_conv1d import ( - causal_conv1d_fn as triton_conv1d_fn, - ) - - cu_seqlens = forward_meta.cu_seqlens_q - mixed_qkv = triton_conv1d_fn( + # Slice cu_seqlens_q to active batch (it's a shared full buffer) + cu_seqlens = forward_meta.cu_seqlens_q[: batch_size + 1] + mixed_qkv = causal_conv1d_fn( x=mixed_qkv.T, # [dim, total_tokens] weight=conv_weight_local, bias=None, @@ -356,7 +315,7 @@ def forward( else: # Prefill/extend: chunk (Triton) - cu_seqlens = forward_meta.cu_seqlens_q + cu_seqlens = forward_meta.cu_seqlens_q[: batch_size + 1] q_4d = q.unsqueeze(0) # [1, total_tokens, H, K] 
k_4d = k.unsqueeze(0) v_4d = v.unsqueeze(0) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 268c675a070..a3cf906e83c 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1357,16 +1357,31 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): # Populate GDN state pool fields if available if getattr(self, "gdn_state_pool", None) is not None: self.forward_meta.gdn_state_pool = self.gdn_state_pool - # gdn_slot_ids: from share_inputs (set by scheduler via insert_tasks_v1) - self.forward_meta.gdn_slot_ids = self.share_inputs.get("gdn_slot_ids") + + # Compute active batch size: requests with seq_lens_this_time > 0. + # share_inputs tensors are pre-allocated to [max_num_seqs], but GDN + # kernels (causal_conv1d_fn, chunk_gated_delta_rule) need tensors + # sized to the actual active batch — unlike AppendAttention which + # has its own metadata to handle the full buffer. + gdn_active_bs = int((self.share_inputs["seq_lens_this_time"] > 0).sum()) + + # gdn_slot_ids: from share_inputs, truncated to active batch + raw_slot_ids = self.share_inputs.get("gdn_slot_ids") + if raw_slot_ids is not None: + self.forward_meta.gdn_slot_ids = raw_slot_ids[:gdn_active_bs] + # gdn_has_initial_state: True if request has prior state (seq_lens_decoder > 0) # For prefill of a new request, seq_lens_decoder=0 → has_initial_state=False # For decode or chunked prefill continuation, seq_lens_decoder>0 → has_initial_state=True if self.forward_meta.seq_lens_decoder is not None: - self.forward_meta.gdn_has_initial_state = self.forward_meta.seq_lens_decoder > 0 - # Derive gdn_seq_lens_cpu from seq_lens_this_time + self.forward_meta.gdn_has_initial_state = self.forward_meta.seq_lens_decoder[:gdn_active_bs] > 0 + + # Derive gdn_seq_lens_cpu from seq_lens_this_time, truncated to active batch if self.forward_meta.seq_lens_this_time is not None: - self.forward_meta.gdn_seq_lens_cpu = 
self.forward_meta.seq_lens_this_time.numpy().tolist() + self.forward_meta.gdn_seq_lens_cpu = ( + self.forward_meta.seq_lens_this_time[:gdn_active_bs].numpy().tolist() + ) + # GDN attention backend if getattr(self, "gdn_attn_backend", None) is not None: self.forward_meta.gdn_attn_backend = self.gdn_attn_backend diff --git a/tests/model_executor/test_gdn_backend.py b/tests/model_executor/test_gdn_backend.py new file mode 100644 index 00000000000..417f0a91050 --- /dev/null +++ b/tests/model_executor/test_gdn_backend.py @@ -0,0 +1,546 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +GDNAttentionBackend integration tests — shape + numerical correctness. + +Tests decode / prefill (extend) / mixed mode forward pass through the full +GDN backend pipeline: conv1d → split Q,K,V → gating → SSM kernel. + +Numerical correctness is verified by comparing the backend output against a +manually-constructed reference pipeline that calls the same Triton kernels +step-by-step, ensuring the backend's data orchestration (reshape, slice, +unsqueeze, GVA repeat, etc.) is correct. 
+
+Run:
+    cd FastDeploy
+    python -m pytest tests/model_executor/test_gdn_backend.py -v
+"""
+
+import unittest
+from enum import Enum, auto
+from types import SimpleNamespace
+
+import numpy as np
+import paddle
+
+from fastdeploy.cache_manager.gdn_state_pool import GDNStatePool
+from fastdeploy.model_executor.layers.attention.gdn_backend import GDNAttentionBackend
+from fastdeploy.model_executor.ops.triton_ops.causal_conv1d import (
+    causal_conv1d_fn,
+    causal_conv1d_update,
+)
+from fastdeploy.model_executor.ops.triton_ops.fla import (
+    chunk_gated_delta_rule,
+    fused_gdn_gating,
+    fused_recurrent_gated_delta_rule_update,
+)
+
+# ============================================================
+# Mock ForwardMode
+# ============================================================
+
+
+class ForwardMode(Enum):
+    EXTEND = auto()
+    DECODE = auto()
+    MIXED = auto()
+
+    def is_decode(self):
+        return self == ForwardMode.DECODE
+
+    def is_mixed(self):
+        return self == ForwardMode.MIXED
+
+
+# ============================================================
+# Test dimensions (small for speed)
+# ============================================================
+
+NUM_K_HEADS = 2
+NUM_V_HEADS = 4
+HEAD_K_DIM = 16
+HEAD_V_DIM = 16
+CONV_KERNEL_SIZE = 4
+NUM_V_HEADS_PER_K_HEAD = NUM_V_HEADS // NUM_K_HEADS
+
+KEY_DIM_LOCAL = NUM_K_HEADS * HEAD_K_DIM
+VALUE_DIM_LOCAL = NUM_V_HEADS * HEAD_V_DIM
+CONV_DIM = KEY_DIM_LOCAL * 2 + VALUE_DIM_LOCAL
+
+
+# ============================================================
+# Helpers
+# ============================================================
+
+
+def make_pool(max_num_seqs=8, num_layers=1):
+    return GDNStatePool(
+        max_num_seqs=max_num_seqs,
+        num_gdn_layers=num_layers,
+        conv_dim=CONV_DIM,
+        conv_kernel_size=CONV_KERNEL_SIZE,
+        num_v_heads=NUM_V_HEADS,
+        head_k_dim=HEAD_K_DIM,
+        head_v_dim=HEAD_V_DIM,
+    )
+
+
+def make_layer():
+    layer = SimpleNamespace()
+    layer.gdn_layer_idx = 0
+    layer.num_k_heads_local = NUM_K_HEADS
+    
layer.num_v_heads_local = NUM_V_HEADS + layer.head_k_dim = HEAD_K_DIM + layer.head_v_dim = HEAD_V_DIM + layer.num_v_heads_per_k_head = NUM_V_HEADS_PER_K_HEAD + layer.conv_dim = CONV_DIM + layer.conv1d_weight = paddle.randn([CONV_DIM, CONV_KERNEL_SIZE], dtype="bfloat16") + layer.A_log = paddle.randn([NUM_V_HEADS], dtype="float32") + layer.dt_bias = paddle.randn([NUM_V_HEADS], dtype="float32") + return layer + + +def make_meta_decode(batch_size, pool): + raw_slot_ids = paddle.arange(0, batch_size, dtype="int32") + meta = SimpleNamespace() + meta.forward_mode = ForwardMode.DECODE + meta.gdn_state_pool = pool + meta.gdn_slot_ids = raw_slot_ids + meta.gdn_has_initial_state = paddle.ones([batch_size], dtype="bool") + meta.gdn_seq_lens_cpu = [1] * batch_size + meta.cu_seqlens_q = paddle.arange(0, batch_size + 1, dtype="int32") + return meta + + +def make_meta_extend(batch_size, seq_lens, pool): + cu = [0] + for sl in seq_lens: + cu.append(cu[-1] + sl) + raw_slot_ids = paddle.arange(0, batch_size, dtype="int32") + meta = SimpleNamespace() + meta.forward_mode = ForwardMode.EXTEND + meta.gdn_state_pool = pool + meta.gdn_slot_ids = raw_slot_ids + meta.gdn_has_initial_state = paddle.zeros([batch_size], dtype="bool") + meta.gdn_seq_lens_cpu = seq_lens + meta.cu_seqlens_q = paddle.to_tensor(cu, dtype="int32") + return meta + + +def make_meta_mixed(num_decode, num_extend, extend_seq_lens, pool): + batch_size = num_decode + num_extend + all_seq_lens = [1] * num_decode + extend_seq_lens + cu = [0] + for sl in all_seq_lens: + cu.append(cu[-1] + sl) + raw_slot_ids = paddle.arange(0, batch_size, dtype="int32") + has_initial = [True] * num_decode + [False] * num_extend + meta = SimpleNamespace() + meta.forward_mode = ForwardMode.MIXED + meta.gdn_state_pool = pool + meta.gdn_slot_ids = raw_slot_ids + meta.gdn_has_initial_state = paddle.to_tensor(has_initial, dtype="bool") + meta.gdn_seq_lens_cpu = all_seq_lens + meta.cu_seqlens_q = paddle.to_tensor(cu, dtype="int32") + return meta + + 
+# ============================================================ +# Reference pipeline: manually call the same Triton kernels +# ============================================================ + + +def ref_decode_pipeline(mixed_qkv, a, b, layer, conv_pool, ssm_pool, slot_ids): + """Manual decode pipeline using Triton kernels directly.""" + conv_weight = layer.conv1d_weight[: layer.conv_dim] + + # 1. Conv1d update + x = causal_conv1d_update( + x=mixed_qkv, + conv_state=conv_pool, + weight=conv_weight, + bias=None, + activation="silu", + conv_state_indices=slot_ids, + ) + + # 2. Split Q, K, V + num_tokens = x.shape[0] + q, k, v = paddle.split( + x, + [KEY_DIM_LOCAL, KEY_DIM_LOCAL, VALUE_DIM_LOCAL], + axis=-1, + ) + q = q.reshape([num_tokens, NUM_K_HEADS, HEAD_K_DIM]) + k = k.reshape([num_tokens, NUM_K_HEADS, HEAD_K_DIM]) + v = v.reshape([num_tokens, NUM_V_HEADS, HEAD_V_DIM]) + + # 3. Gating + g, beta = fused_gdn_gating( + layer.A_log[:NUM_V_HEADS], + a[:, :NUM_V_HEADS], + b[:, :NUM_V_HEADS], + layer.dt_bias[:NUM_V_HEADS], + ) + + # 4. GVA repeat + if NUM_V_HEADS_PER_K_HEAD > 1: + q = ( + q.unsqueeze(2) + .expand([num_tokens, NUM_K_HEADS, NUM_V_HEADS_PER_K_HEAD, HEAD_K_DIM]) + .reshape([num_tokens, NUM_V_HEADS, HEAD_K_DIM]) + ) + k = ( + k.unsqueeze(2) + .expand([num_tokens, NUM_K_HEADS, NUM_V_HEADS_PER_K_HEAD, HEAD_K_DIM]) + .reshape([num_tokens, NUM_V_HEADS, HEAD_K_DIM]) + ) + + # 5. 
SSM kernel (decode) + q_4d = q.unsqueeze(1) + k_4d = k.unsqueeze(1) + v_4d = v.unsqueeze(1) + g_4d = g.unsqueeze(1) + beta_4d = beta.unsqueeze(1) + + o = fused_recurrent_gated_delta_rule_update( + q=q_4d, + k=k_4d, + v=v_4d, + g=g_4d, + beta=beta_4d, + ssm_pool=ssm_pool, + ssm_indices=slot_ids, + use_qk_l2norm_in_kernel=True, + ) + return o.squeeze(1) # [batch, H, V] + + +def ref_extend_pipeline( + mixed_qkv, a, b, layer, conv_pool, ssm_pool, slot_ids, cu_seqlens, seq_lens_cpu, has_initial_state +): + """Manual extend pipeline using Triton kernels directly.""" + conv_weight = layer.conv1d_weight[: layer.conv_dim] + + # 1. Conv1d fn (varlen) + x = causal_conv1d_fn( + x=mixed_qkv.T, + weight=conv_weight, + bias=None, + conv_states=conv_pool, + query_start_loc=cu_seqlens, + seq_lens_cpu=seq_lens_cpu, + cache_indices=slot_ids, + has_initial_state=has_initial_state, + activation="silu", + ).T + + # 2. Split Q, K, V + num_tokens = x.shape[0] + q, k, v = paddle.split( + x, + [KEY_DIM_LOCAL, KEY_DIM_LOCAL, VALUE_DIM_LOCAL], + axis=-1, + ) + q = q.reshape([num_tokens, NUM_K_HEADS, HEAD_K_DIM]) + k = k.reshape([num_tokens, NUM_K_HEADS, HEAD_K_DIM]) + v = v.reshape([num_tokens, NUM_V_HEADS, HEAD_V_DIM]) + + # 3. Gating + g, beta = fused_gdn_gating( + layer.A_log[:NUM_V_HEADS], + a[:, :NUM_V_HEADS], + b[:, :NUM_V_HEADS], + layer.dt_bias[:NUM_V_HEADS], + ) + + # 4. GVA repeat + if NUM_V_HEADS_PER_K_HEAD > 1: + q = ( + q.unsqueeze(2) + .expand([num_tokens, NUM_K_HEADS, NUM_V_HEADS_PER_K_HEAD, HEAD_K_DIM]) + .reshape([num_tokens, NUM_V_HEADS, HEAD_K_DIM]) + ) + k = ( + k.unsqueeze(2) + .expand([num_tokens, NUM_K_HEADS, NUM_V_HEADS_PER_K_HEAD, HEAD_K_DIM]) + .reshape([num_tokens, NUM_V_HEADS, HEAD_K_DIM]) + ) + + # 5. 
SSM kernel (chunk) + batch_size = slot_ids.shape[0] + q_4d = q.unsqueeze(0) + k_4d = k.unsqueeze(0) + v_4d = v.unsqueeze(0) + g_4d = g.unsqueeze(0) + beta_4d = beta.unsqueeze(0) + + initial_state = ssm_pool[slot_ids].clone() + o, _h = chunk_gated_delta_rule( + q=q_4d, + k=k_4d, + v=v_4d, + g=g_4d, + beta=beta_4d, + scale=None, + initial_state=initial_state, + initial_state_indices=paddle.arange(batch_size, dtype=paddle.int32), + cu_seqlens=cu_seqlens, + use_qk_l2norm_in_kernel=True, + ) + + # Write back states + for i in range(batch_size): + sid = int(slot_ids[i]) + if sid > 0: + ssm_pool[sid] = initial_state[i] + + return o.squeeze(0) # [total_tokens, H, V] + + +# ============================================================ +# Test Cases +# ============================================================ + + +class TestGDNBackendDecodeNumerical(unittest.TestCase): + """Decode: backend vs manual reference — numerical match.""" + + def test_decode_numerical(self): + batch_size = 3 + pool = make_pool(max_num_seqs=8) + layer = make_layer() + backend = GDNAttentionBackend() + meta = make_meta_decode(batch_size, pool) + + mixed_qkv = paddle.randn([batch_size, CONV_DIM], dtype="bfloat16") + a = paddle.randn([batch_size, NUM_V_HEADS], dtype="bfloat16") + b = paddle.randn([batch_size, NUM_V_HEADS], dtype="bfloat16") + + # Clone pool + inputs so both paths see the same initial state + pool_ref = make_pool(max_num_seqs=8) + slot_ids = GDNStatePool.offset_slot_ids(meta.gdn_slot_ids) + + # Reference + ref_out = ref_decode_pipeline( + mixed_qkv.clone(), + a.clone(), + b.clone(), + layer, + pool_ref.get_layer_conv_pool(0), + pool_ref.get_layer_ssm_pool(0), + slot_ids, + ) + + # Backend + out = backend.forward(mixed_qkv.clone(), a.clone(), b.clone(), layer, meta) + + self.assertEqual(out.shape, [batch_size, NUM_V_HEADS, HEAD_V_DIM]) + np.testing.assert_allclose( + out.cast("float32").numpy(), + ref_out.cast("float32").numpy(), + rtol=1e-3, + atol=1e-3, + ) + + +class 
TestGDNBackendExtendNumerical(unittest.TestCase): + """Extend: backend vs manual reference — numerical match.""" + + def test_extend_single_seq(self): + seq_lens = [10] + pool = make_pool(max_num_seqs=8) + layer = make_layer() + backend = GDNAttentionBackend() + meta = make_meta_extend(1, seq_lens, pool) + + total = sum(seq_lens) + mixed_qkv = paddle.randn([total, CONV_DIM], dtype="bfloat16") + a = paddle.randn([total, NUM_V_HEADS], dtype="bfloat16") + b = paddle.randn([total, NUM_V_HEADS], dtype="bfloat16") + + pool_ref = make_pool(max_num_seqs=8) + slot_ids = GDNStatePool.offset_slot_ids(meta.gdn_slot_ids) + + ref_out = ref_extend_pipeline( + mixed_qkv.clone(), + a.clone(), + b.clone(), + layer, + pool_ref.get_layer_conv_pool(0), + pool_ref.get_layer_ssm_pool(0), + slot_ids, + meta.cu_seqlens_q.clone(), + list(meta.gdn_seq_lens_cpu), + meta.gdn_has_initial_state.clone(), + ) + + out = backend.forward(mixed_qkv.clone(), a.clone(), b.clone(), layer, meta) + + self.assertEqual(out.shape, [total, NUM_V_HEADS, HEAD_V_DIM]) + np.testing.assert_allclose( + out.cast("float32").numpy(), + ref_out.cast("float32").numpy(), + rtol=1e-3, + atol=1e-3, + ) + + def test_extend_multi_seq(self): + seq_lens = [5, 8, 3] + pool = make_pool(max_num_seqs=8) + layer = make_layer() + backend = GDNAttentionBackend() + meta = make_meta_extend(3, seq_lens, pool) + + total = sum(seq_lens) + mixed_qkv = paddle.randn([total, CONV_DIM], dtype="bfloat16") + a = paddle.randn([total, NUM_V_HEADS], dtype="bfloat16") + b = paddle.randn([total, NUM_V_HEADS], dtype="bfloat16") + + pool_ref = make_pool(max_num_seqs=8) + slot_ids = GDNStatePool.offset_slot_ids(meta.gdn_slot_ids) + + ref_out = ref_extend_pipeline( + mixed_qkv.clone(), + a.clone(), + b.clone(), + layer, + pool_ref.get_layer_conv_pool(0), + pool_ref.get_layer_ssm_pool(0), + slot_ids, + meta.cu_seqlens_q.clone(), + list(meta.gdn_seq_lens_cpu), + meta.gdn_has_initial_state.clone(), + ) + + out = backend.forward(mixed_qkv.clone(), a.clone(), 
b.clone(), layer, meta) + + self.assertEqual(out.shape, [total, NUM_V_HEADS, HEAD_V_DIM]) + np.testing.assert_allclose( + out.cast("float32").numpy(), + ref_out.cast("float32").numpy(), + rtol=1e-3, + atol=1e-3, + ) + + +class TestGDNBackendMixedNumerical(unittest.TestCase): + """Mixed mode: backend vs manual reference — numerical match.""" + + def test_mixed_numerical(self): + """2 decode (seqlen=1) + 1 extend (seqlen=6), all through extend path.""" + num_decode = 2 + num_extend = 1 + extend_seq_lens = [6] + + pool = make_pool(max_num_seqs=8) + layer = make_layer() + backend = GDNAttentionBackend() + meta = make_meta_mixed(num_decode, num_extend, extend_seq_lens, pool) + + total = num_decode + sum(extend_seq_lens) + mixed_qkv = paddle.randn([total, CONV_DIM], dtype="bfloat16") + a = paddle.randn([total, NUM_V_HEADS], dtype="bfloat16") + b = paddle.randn([total, NUM_V_HEADS], dtype="bfloat16") + + pool_ref = make_pool(max_num_seqs=8) + slot_ids = GDNStatePool.offset_slot_ids(meta.gdn_slot_ids) + + ref_out = ref_extend_pipeline( + mixed_qkv.clone(), + a.clone(), + b.clone(), + layer, + pool_ref.get_layer_conv_pool(0), + pool_ref.get_layer_ssm_pool(0), + slot_ids, + meta.cu_seqlens_q.clone(), + list(meta.gdn_seq_lens_cpu), + meta.gdn_has_initial_state.clone(), + ) + out = backend.forward(mixed_qkv.clone(), a.clone(), b.clone(), layer, meta) + + self.assertEqual(out.shape, [total, NUM_V_HEADS, HEAD_V_DIM]) + np.testing.assert_allclose( + out.cast("float32").numpy(), + ref_out.cast("float32").numpy(), + rtol=1e-3, + atol=1e-3, + ) + + +class TestGDNBackendStateUpdate(unittest.TestCase): + """Verify SSM/conv states persist across prefill → decode.""" + + def test_prefill_then_decode_state_persists(self): + pool = make_pool(max_num_seqs=8) + layer = make_layer() + backend = GDNAttentionBackend() + + # Step 1: Prefill seqlen=5 + meta_p = make_meta_extend(1, [5], pool) + mixed_qkv = paddle.randn([5, CONV_DIM], dtype="bfloat16") + a = paddle.randn([5, NUM_V_HEADS], 
dtype="bfloat16") + b = paddle.randn([5, NUM_V_HEADS], dtype="bfloat16") + out1 = backend.forward(mixed_qkv, a, b, layer, meta_p) + self.assertEqual(out1.shape, [5, NUM_V_HEADS, HEAD_V_DIM]) + + # SSM state should be non-zero after prefill (slot 0 → offset 1) + ssm_state = pool.get_layer_ssm_pool(0)[1].numpy() + self.assertFalse((ssm_state == 0).all(), "SSM state should be non-zero after prefill") + + # Conv state should be non-zero after prefill + conv_state = pool.get_layer_conv_pool(0)[1].cast("float32").numpy() + self.assertFalse((conv_state == 0).all(), "Conv state should be non-zero after prefill") + + # Step 2: Decode 1 token (same slot) + meta_d = make_meta_decode(1, pool) + meta_d.gdn_has_initial_state = paddle.to_tensor([True], dtype="bool") + out2 = backend.forward( + paddle.randn([1, CONV_DIM], dtype="bfloat16"), + paddle.randn([1, NUM_V_HEADS], dtype="bfloat16"), + paddle.randn([1, NUM_V_HEADS], dtype="bfloat16"), + layer, + meta_d, + ) + self.assertEqual(out2.shape, [1, NUM_V_HEADS, HEAD_V_DIM]) + + # SSM state should have changed after decode step + ssm_state_after = pool.get_layer_ssm_pool(0)[1].numpy() + self.assertFalse( + np.allclose(ssm_state, ssm_state_after, atol=1e-10), "SSM state should change after decode step" + ) + + def test_output_not_all_zeros(self): + """Sanity: output should not be trivially zero.""" + pool = make_pool(max_num_seqs=8) + layer = make_layer() + backend = GDNAttentionBackend() + meta = make_meta_extend(1, [8], pool) + + out = backend.forward( + paddle.randn([8, CONV_DIM], dtype="bfloat16"), + paddle.randn([8, NUM_V_HEADS], dtype="bfloat16"), + paddle.randn([8, NUM_V_HEADS], dtype="bfloat16"), + layer, + meta, + ) + self.assertFalse( + (out.cast("float32").numpy() == 0).all(), + "Backend output should not be all zeros", + ) + + +if __name__ == "__main__": + unittest.main() From 0464a16d02d69cbce970e99f4c483dd9b3ea4185 Mon Sep 17 00:00:00 2001 From: liukebin Date: Tue, 31 Mar 2026 21:40:27 +0800 Subject: [PATCH 6/6] 
[Feature] update gdn backend for qwen3.5 --- .../layers/attention/gdn_attention.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/fastdeploy/model_executor/layers/attention/gdn_attention.py b/fastdeploy/model_executor/layers/attention/gdn_attention.py index fc9f67a10aa..fa590f9e160 100644 --- a/fastdeploy/model_executor/layers/attention/gdn_attention.py +++ b/fastdeploy/model_executor/layers/attention/gdn_attention.py @@ -21,9 +21,9 @@ Usage:: class Qwen3_5GatedDeltaNet(nn.Layer): - def __init__(self, ...): + def __init__(self, fd_config, layer_id, ...): ... - self.gdn_attn = GDNAttention() + self.gdn_attn = GDNAttention(fd_config, layer_id) def forward(self, forward_meta, hidden_states): ... @@ -37,6 +37,8 @@ def forward(self, forward_meta, hidden_states): import paddle from paddle import nn +from fastdeploy.config import FDConfig + if TYPE_CHECKING: from fastdeploy.model_executor.forward_meta import ForwardMeta @@ -45,12 +47,18 @@ class GDNAttention(nn.Layer): """GDN (Gated Delta Network) linear attention trampoline. Mirrors the role of :class:`Attention` for softmax attention: - the model layer holds ``self.gdn_attn = GDNAttention()`` and calls - ``self.gdn_attn(mixed_qkv, a, b, self, forward_meta)`` in its forward. + the model layer holds ``self.gdn_attn = GDNAttention(fd_config, layer_id)`` + and calls ``self.gdn_attn(mixed_qkv, a, b, self, forward_meta)`` in its + forward. Internally delegates to ``forward_meta.gdn_attn_backend.forward()``. """ + def __init__(self, fd_config: FDConfig, layer_id: int) -> None: + super().__init__() + self.fd_config = fd_config + self.layer_id = layer_id + def forward( self, mixed_qkv: paddle.Tensor,