Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 229 additions & 0 deletions fastdeploy/cache_manager/gdn_state_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
GDN (Gated Delta Network) State Pool — pre-allocated GPU tensor pool
for conv and SSM states used by Qwen3.5 linear attention layers.

Design:
- Analogous to paged KV cache block pool, but each slot stores a
complete per-request state (conv state + SSM state).
- All GDN layers share a single pool object, indexed by layer_idx.
- Slot 0 is reserved as a zero-filled padding sentinel.
PAD_SLOT_ID (-1) is mapped to the sentinel slot 0 when building
gdn_slot_ids in ForwardMeta, so reads of padded entries return zero
and writes to them are harmless.

Pool layouts:
conv_pool: [num_gdn_layers, pool_size, conv_dim, conv_kernel_size - 1]
ssm_pool: [num_gdn_layers, pool_size, num_v_heads, head_k_dim, head_v_dim]

where pool_size = max_num_seqs + 1 (slot 0 = padding sentinel).
"""

import logging
from typing import List

import paddle

logger = logging.getLogger(__name__)

PAD_SLOT_ID = -1


class GDNSlotAllocator:
"""Lightweight CPU-only slot allocator for GDN state pool.

Used by ResourceManagerV1 on the scheduler side to manage slot IDs
without requiring paddle/GPU access. The corresponding GPU tensors
live in GDNStatePool on the worker side.

Slot 0 is reserved as a padding sentinel. Valid slots: 1..max_num_seqs.
"""

def __init__(self, max_num_seqs: int):
self.max_num_seqs = max_num_seqs
self._free_slots: List[int] = list(range(max_num_seqs, 0, -1))

def allocate(self) -> int:
"""Allocate a single slot ID.

Returns:
Allocated slot ID (1-based).

Raises:
RuntimeError: If no free slots available.
"""
if not self._free_slots:
raise RuntimeError(f"GDNSlotAllocator: no free slots (max_num_seqs={self.max_num_seqs})")
return self._free_slots.pop()
Comment on lines +54 to +69
Copy link

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前 slot_id 的约定在 scheduler/worker/test 之间不一致,会导致实际运行时索引越界:GDNSlotAllocator.allocate() 返回 1..max_num_seqs,但 offset_slot_ids 直接做 +1,会把最大 slot_id 映射到 max_num_seqs+1(超出 pool_size=max_num_seqs+1 的最大下标 max_num_seqs)。建议统一约定:要么 allocator 返回 0..max_num_seqs-1 并保留 +1 offset;要么保持 allocator 1-based,但 offset_slot_ids 改成仅把 PAD_SLOT_ID(-1) 映射到 0(例如 clip 到 >=0),不要整体 +1。

Copilot uses AI. Check for mistakes.

def free(self, slot_id: int):
"""Return a slot ID to the free list.

Args:
slot_id: Slot ID to free (1-based). Slot 0 is silently ignored.
"""
if slot_id > 0:
self._free_slots.append(slot_id)

@property
def num_free_slots(self) -> int:
    """Number of slot IDs currently available for allocation."""
    return len(self._free_slots)


class GDNStatePool:
"""Pre-allocated GPU tensor pool for GDN conv and SSM states.

Args:
max_num_seqs: Maximum number of concurrent sequences.
num_gdn_layers: Number of GDN (linear_attention) layers in the model.
conv_dim: TP-local convolution dimension (key_dim * 2 + value_dim) // tp_size.
conv_kernel_size: Causal conv1d kernel width (e.g. 4).
num_v_heads: TP-local number of value heads (num_v_heads // tp_size).
head_k_dim: Per-head key dimension.
head_v_dim: Per-head value dimension.
conv_dtype: Data type for conv state pool (default: bfloat16).
"""

def __init__(
    self,
    max_num_seqs: int,
    num_gdn_layers: int,
    conv_dim: int,
    conv_kernel_size: int,
    num_v_heads: int,
    head_k_dim: int,
    head_v_dim: int,
    conv_dtype: paddle.dtype = paddle.bfloat16,
):
    """Allocate the conv/SSM state pools and initialize the free-slot list."""
    self.max_num_seqs = max_num_seqs
    self.num_gdn_layers = num_gdn_layers
    self.conv_dim = conv_dim
    self.conv_kernel_size = conv_kernel_size
    self.num_v_heads = num_v_heads
    self.head_k_dim = head_k_dim
    self.head_v_dim = head_v_dim

    # One extra slot so index 0 can serve as the zero-filled padding sentinel.
    slot_count = max_num_seqs + 1
    # A causal conv of width k only needs the k-1 previous inputs as state.
    conv_tail_len = conv_kernel_size - 1

    # Conv state pool: [num_gdn_layers, slot_count, conv_dim, conv_tail_len]
    self.conv_pool = paddle.zeros(
        [num_gdn_layers, slot_count, conv_dim, conv_tail_len],
        dtype=conv_dtype,
    )

    # SSM state pool: [num_gdn_layers, slot_count, num_v_heads, head_k_dim, head_v_dim]
    # K-first layout matching the FLA kernel's native format; float32 for
    # numerical stability since the SSM state accumulates over many steps.
    self.ssm_pool = paddle.zeros(
        [num_gdn_layers, slot_count, num_v_heads, head_k_dim, head_v_dim],
        dtype=paddle.float32,
    )

    mib = 1024 * 1024
    conv_mb = (num_gdn_layers * slot_count * conv_dim * conv_tail_len * paddle.finfo(conv_dtype).bits // 8) / mib
    ssm_mb = (num_gdn_layers * slot_count * num_v_heads * head_k_dim * head_v_dim * 4) / mib
    logger.info(
        f"GDNStatePool allocated: "
        f"conv_pool {list(self.conv_pool.shape)} ({conv_mb:.1f} MB), "
        f"ssm_pool {list(self.ssm_pool.shape)} ({ssm_mb:.1f} MB)"
    )

    # Free slot list holds 1..max_num_seqs (slot 0 is the sentinel),
    # descending so pop() hands out the lowest slot IDs first.
    self._free_slots: List[int] = sorted(range(1, max_num_seqs + 1), reverse=True)

    logger.info(
        f"GDNStatePool: {len(self._free_slots)} free slots available "
        f"(slot 0 reserved as padding sentinel)"
    )

def allocate(self, n: int = 1) -> List[int]:
    """Take n slot IDs off the free list.

    Args:
        n: Number of slots to allocate.

    Returns:
        List of allocated slot IDs (1-based pool indices).

    Raises:
        RuntimeError: If fewer than n slots are free.
    """
    available = len(self._free_slots)
    if available < n:
        raise RuntimeError(f"GDNStatePool: cannot allocate {n} slots, only {available} free")
    return [self._free_slots.pop() for _ in range(n)]

def free(self, slot_ids: List[int]):
    """Zero-out the given slots and return them to the free list.

    Args:
        slot_ids: Slot IDs to free (1-based pool indices). The padding
            sentinel (slot 0) and any non-positive ID is silently skipped.
    """
    recyclable = [slot for slot in slot_ids if slot > 0]
    if recyclable:
        # State must be cleared before reuse so a new request starts clean.
        self.reset_slots(recyclable)
        self._free_slots.extend(recyclable)

@property
def num_free_slots(self) -> int:
    """Number of currently available slots (slot 0 sentinel excluded)."""
    return len(self._free_slots)

def get_layer_conv_pool(self, layer_idx: int) -> paddle.Tensor:
    """Get conv state pool for a specific GDN layer.

    Args:
        layer_idx: Index into the GDN-layer axis of the pool (0-based,
            counting only GDN layers, not all model layers).

    Returns:
        Tensor of shape [pool_size, conv_dim, conv_kernel_size - 1]
    """
    return self.conv_pool[layer_idx]

def get_layer_ssm_pool(self, layer_idx: int) -> paddle.Tensor:
    """Get SSM state pool for a specific GDN layer.

    Args:
        layer_idx: Index into the GDN-layer axis of the pool (0-based,
            counting only GDN layers, not all model layers).

    Returns:
        Tensor of shape [pool_size, num_v_heads, head_k_dim, head_v_dim]
    """
    return self.ssm_pool[layer_idx]

def reset_slots(self, slot_ids: List[int]):
    """Zero-out conv and SSM state for given slots across all layers.

    Used when requests finish and their slots are returned to the free list.

    Args:
        slot_ids: Pool slot indices to reset (1-based; callers such as
            free() filter out the slot-0 sentinel before calling).
    """
    if not slot_ids:
        return
    idx = paddle.to_tensor(slot_ids, dtype=paddle.int64)
    # NOTE(review): chained indexing (pool[layer_idx][idx] = 0) relies on
    # paddle's __getitem__ returning a view that writes through to the
    # underlying pool storage — confirm for the paddle version in use.
    for layer_idx in range(self.num_gdn_layers):
        self.conv_pool[layer_idx][idx] = 0
        self.ssm_pool[layer_idx][idx] = 0

@staticmethod
def offset_slot_ids(raw_slot_ids: paddle.Tensor) -> paddle.Tensor:
    """Map raw slot IDs to pool indices, sending PAD_SLOT_ID (-1) to slot 0.

    Slot IDs produced by GDNSlotAllocator / GDNStatePool.allocate() are
    already 1-based pool indices (slot 0 is the zero-filled padding
    sentinel), so valid IDs must pass through unchanged. The previous
    blanket ``+ 1`` shifted every valid ID and could index
    ``max_num_seqs + 1`` — one past the end of a pool of size
    ``max_num_seqs + 1`` (max valid index ``max_num_seqs``).

    Args:
        raw_slot_ids: [batch_size] int32, may contain PAD_SLOT_ID (-1).

    Returns:
        Pool indices where -1 -> 0 and every valid ID i (>= 1) stays i.
    """
    # Only the padding sentinel is remapped; clipping at 0 leaves all
    # allocated (strictly positive) slot IDs untouched.
    return paddle.clip(raw_slot_ids, min=0)
Comment on lines +219 to +229
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

当前 slot_id 语义在模块内不一致:GDNSlotAllocator / GDNStatePool.allocate() 都返回 1-based(slot 0 仅哨兵),但 offset_slot_ids() 直接做 +1 会把有效 slot 整体右移并可能产生 max_num_seqs+1 的越界索引(pool_size=max_num_seqs+1 最大下标是 max_num_seqs)。建议统一约定(推荐全链路 1-based),并把 offset_slot_ids 改为仅将 PAD_SLOT_ID(-1) 映射到 0、其它值保持不变,同时同步更新顶部注释中“+1 offset”的描述。

Copilot uses AI. Check for mistakes.
Comment on lines +220 to +229
Copy link

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

offset_slot_ids 当前实现为 raw_slot_ids + 1,但与上游写入的 gdn_slot_id(ResourceManagerV1/insert_tasks_v1 直接透传 allocator 分配结果)不匹配,容易把合法 slot 映射到越界位置。这里建议改成只处理 PAD_SLOT_ID:-1→0,其它保持不变(或相应地把 allocator 改为 0-based 并同步所有调用方/测试)。同时更新函数注释与单测期望,避免后续误用。

Copilot uses AI. Check for mistakes.
3 changes: 3 additions & 0 deletions fastdeploy/engine/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ def __init__(
self.task_type = RequestType.PREFILL
self.has_been_preempted_before = False
self.idx = None
# GDN (Gated Delta Network) SSM state slot ID in GDNStatePool.
# Allocated by ResourceManagerV1 during schedule(), freed in _free_blocks().
self.gdn_slot_id = None
self.need_prefill_tokens = self.prompt_token_ids_len
self.audio_output_token_ids = []
# extend block tables
Expand Down
38 changes: 38 additions & 0 deletions fastdeploy/engine/sched/resource_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import paddle

from fastdeploy import envs
from fastdeploy.cache_manager.gdn_state_pool import GDNSlotAllocator
from fastdeploy.cache_manager.multimodal_cache_manager import (
EncoderCacheManager,
ProcessorCacheManager,
Expand Down Expand Up @@ -231,6 +232,32 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l
# Scheduler-side requests that have not been moved into resource manager waiting queue yet.
self.scheduler_unhandled_request_num = 0

# GDN SSM slot allocator (None if model has no GDN layers)
self.gdn_slot_allocator = self._init_gdn_slot_allocator()

def _init_gdn_slot_allocator(self):
    """Create GDN slot allocator if the model has GDN (linear_attention) layers.

    Returns:
        A GDNSlotAllocator sized to max_num_seqs, or None when the model
        declares no linear_attention layers (or exposes neither
        ``layer_types`` nor ``full_attention_interval``).
    """
    model_config = self.config.model_config
    layer_types = getattr(model_config, "layer_types", None)
    if layer_types is None:
        # Generate from full_attention_interval if not explicit:
        # every interval-th layer (1-based) is full attention, all
        # other layers are linear_attention (GDN).
        interval = getattr(model_config, "full_attention_interval", None)
        if interval is None:
            # No layer-type information at all -> treat as no GDN layers.
            return None
        num_layers = model_config.num_hidden_layers
        layer_types = [
            "linear_attention" if (i + 1) % interval != 0 else "full_attention" for i in range(num_layers)
        ]
    num_gdn_layers = sum(1 for lt in layer_types if lt == "linear_attention")
    if num_gdn_layers == 0:
        return None
    allocator = GDNSlotAllocator(self.config.scheduler_config.max_num_seqs)
    llm_logger.info(
        f"GDN slot allocator initialized: {num_gdn_layers} GDN layers, "
        f"max_num_seqs={self.config.scheduler_config.max_num_seqs}"
    )
    return allocator

def allocated_slots(self, request: Request):
    # Token capacity currently reserved for this request: one block_size
    # worth of tokens per GPU block in its block table.
    return len(request.block_tables) * self.config.cache_config.block_size

Expand Down Expand Up @@ -997,6 +1024,9 @@ def _allocate_decode_and_extend():
request.block_tables.extend(extra_gpu_block_ids)
self.waiting.popleft()
self.running.append(request)
# Allocate GDN SSM state slot if model has GDN layers
if self.gdn_slot_allocator is not None and request.gdn_slot_id is None:
request.gdn_slot_id = self.gdn_slot_allocator.allocate()
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
Comment on lines +1027 to 1030
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前只在 WAITING->RUNNING / PREEMPTED 重新入队时分配 gdn_slot_id;但 splitwise_role==prefill 场景下 common_engine 会通过 add_request_in_p() 直接把请求塞进 running 队列,走“running 队列的 prefill 分支”时不会触发这里的分配逻辑,导致 request.gdn_slot_id 仍为 None,worker 侧 gdn_slot_ids 可能保持为 PAD 或旧值从而污染状态。建议在请求首次进入 running 时(如 add_request_in_p()/preallocate_resource_in_p 或 schedule 的 running-prefill 分支)补充分配。

Copilot uses AI. Check for mistakes.
token_budget -= num_new_tokens
request.num_computed_tokens += num_new_tokens
Expand Down Expand Up @@ -1062,6 +1092,9 @@ def _allocate_decode_and_extend():
request.block_tables.extend(extra_gpu_block_ids)
self.waiting.popleft()
self.running.append(request)
# Re-allocate GDN slot for preempted request
if self.gdn_slot_allocator is not None and request.gdn_slot_id is None:
request.gdn_slot_id = self.gdn_slot_allocator.allocate()
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
token_budget -= num_new_tokens
request.num_computed_tokens += num_new_tokens
Expand Down Expand Up @@ -1452,6 +1485,11 @@ def _free_blocks(self, request: Request):
self.cache_manager.recycle_gpu_blocks(request.block_tables, request.request_id)
request.block_tables = []

# Free GDN SSM state slot
if self.gdn_slot_allocator is not None and request.gdn_slot_id is not None:
self.gdn_slot_allocator.free(request.gdn_slot_id)
request.gdn_slot_id = None

if request.request_id in self.using_extend_tables_req_id:
reuse_block_num = self.reuse_block_num_map[request.request_id]

Expand Down
17 changes: 16 additions & 1 deletion fastdeploy/model_executor/forward_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import logging
from dataclasses import dataclass, fields
from enum import IntEnum, auto
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import paddle

Expand Down Expand Up @@ -160,6 +160,21 @@ class ForwardMeta:

position_ids: Optional[paddle.Tensor] = None

# ============================================================
# GDN (Gated Delta Network) linear attention fields
# ============================================================
# GDN state pool object (shared across all GDN layers)
gdn_state_pool: Optional[Any] = None
# Slot indices into the GDN state pool [batch_size], int32.
# PAD_SLOT_ID=-1 requests are offset to slot 0 (zero-filled sentinel).
gdn_slot_ids: Optional[paddle.Tensor] = None
# Whether each request has prior state (False = new request) [batch_size], bool
gdn_has_initial_state: Optional[paddle.Tensor] = None
# CPU sequence lengths for causal_conv1d_fn varlen [batch_size], int32
gdn_seq_lens_cpu: Optional[List[int]] = None
# GDN attention backend (GDNAttentionBackend instance, inherits AttentionBackend)
gdn_attn_backend: Optional[Any] = None

def clear_caches(self):
"""Safely clean up the caches"""
if self.caches:
Expand Down
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/layers/attention/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .dsa_attention_backend import DSAAttentionBackend
from .flash_attn_backend import FlashAttentionBackend
from .flash_mask_attn_backend import FlashMaskAttentionBackend
from .gdn_attention import GDNAttention
from .mla_attention_backend import MLAAttentionBackend
from .moba_attention_backend import PlasAttentionBackend
from .native_paddle_backend import PaddleNativeAttnBackend
Expand All @@ -34,6 +35,7 @@
"FlashAttentionBackend",
"BlockAttentionBackend",
"Attention",
"GDNAttention",
"PlasAttentionBackend",
"FlashMaskAttentionBackend",
]
Loading
Loading