
[Qwen3.5][Feature][KVCache] support gdn kv cache manager and backend for qwen3.5 #7074

Open
wanderHZ wants to merge 6 commits into PaddlePaddle:develop from wanderHZ:add_gdn_cache_manager

Conversation

@wanderHZ wanderHZ commented Mar 30, 2026

Motivation

To support full Prefill + Decode inference for Qwen3.5 GatedDeltaNet (GDN) linear attention, FastDeploy needs complete cache management for the GDN SSM state and conv state, plus a standalone GDN attention backend.

Unlike the paged KV cache of standard softmax attention, the state of GDN linear attention is a fixed-size dense matrix (SSM state [H, K, V] + conv state [dim, K-1]) that does not grow with sequence length. We therefore adopt a pre-allocated tensor pool + slot allocator scheme, mirroring how vLLM v1 manages Mamba/SSM state.

Building on the previous stage (Triton kernel integration), this PR completes two pieces of work:

  1. Cache management: full lifecycle management of GDN state — allocate → pass through → read/write → release — covering Prefill, Decode, Chunked Prefill, and Preemption.
  2. Attention backend: a new GDNAttentionBackend (subclassing the FD AttentionBackend) that routes to Triton / Paddle kernels internally via the GDNKernelDispatcher strategy pattern, providing a unified forward() interface to the model layer.

Modifications

New files

| File | Description |
| --- | --- |
| fastdeploy/cache_manager/gdn_state_pool.py | GDNStatePool (GPU tensor pool) + GDNSlotAllocator (CPU-only slot allocator) |
| fastdeploy/model_executor/layers/attention/gdn_backend.py | GDNKernelDispatcher (strategy-pattern routing) + GDNAttentionBackend (subclass of AttentionBackend); fused_gdn_gating is imported directly from the fla package (Triton-only, ~337 lines) |
| fastdeploy/model_executor/layers/attention/gdn_attention.py | GDNAttention trampoline layer — analogous to Attention; model layers use it via self.gdn_attn = GDNAttention() |
| tests/cache_manager/test_gdn_state_pool.py | 28 unit tests covering all pool and allocator functionality |
| tests/model_executor/ops/triton_ops/test_gdn_backend.py | 6 numerical integration tests: decode / extend (single + multi sequence) / mixed-mode numerical comparison + state-persistence checks |

Modified files

| File | Change | Description |
| --- | --- | --- |
| fastdeploy/model_executor/forward_meta.py | +5 fields | gdn_state_pool, gdn_slot_ids, gdn_has_initial_state, gdn_seq_lens_cpu, gdn_attn_backend |
| fastdeploy/engine/request.py | +1 line | add request.gdn_slot_id = None |
| fastdeploy/worker/input_batch.py | +1 line | add gdn_slot_ids tensor [max_num_seqs] int32 |
| fastdeploy/engine/sched/resource_manager_v1.py | +3 sites | _init_gdn_slot_allocator() + 2 slot allocations in schedule() + slot release in _free_blocks() |
| fastdeploy/worker/gpu_model_runner.py | +5 sites | insert_tasks_v1() writes gdn_slot_ids + initialize_forward_meta() fills the GDN fields (including active-batch trimming) + _initialize_gdn_state_pool() creates the pool and backend |
| fastdeploy/model_executor/layers/attention/__init__.py | +2 lines | import and __all__ export for GDNAttention |

Architecture

Overall architecture

Scheduler side (CPU)                     Worker side (GPU)
┌──────────────────────┐                 ┌──────────────────────────────────┐
│  GDNSlotAllocator    │                 │  GDNStatePool                    │
│  ├─ allocate() → id  │                 │  ├─ conv_pool [L,P,D,K-1]  bf16 │
│  └─ free(id)         │                 │  ├─ ssm_pool  [L,P,H,K,V]  fp32 │
│                      │    slot_id      │  └─ offset_slot_ids()            │
│  ResourceManagerV1   │ ──────────────▶ │                                  │
│  schedule() / free() │  via request    │  GDNAttentionBackend             │
└──────────────────────┘                 │  ├─ GDNKernelDispatcher          │
                                         │  │   ├─ decode()   → Triton      │
                                         │  │   └─ extend()   → Triton      │
                                         │  └─ forward(mixed_qkv, a, b, ..) │
                                         └──────────────────────────────────┘

Slot lifecycle

schedule()                → allocate slot (gdn_slot_allocator.allocate())
  ↓
request.gdn_slot_id       → carried through the whole scheduler → worker chain
  ↓
insert_tasks_v1()         → written into share_inputs["gdn_slot_ids"][idx]
  ↓
initialize_forward_meta() → fills the 5 GDN fields on forward_meta
  ↓
model forward             → GDNAttention.forward() → gdn_attn_backend.forward() → GDNKernelDispatcher → pool R/W
  ↓
_free_blocks()            → release slot (gdn_slot_allocator.free())

GDNAttentionBackend internals

gdn_backend.py
├─ GDNAttentionBackend(AttentionBackend)
│   └─ forward(mixed_qkv, a, b, layer, forward_meta)
│       ├─ offset slot_ids (PAD -1 → 0)
│       ├─ Conv1d (decode: Triton update / prefill: Triton fn)
│       ├─ Split Q,K,V + fused_gdn_gating → g, beta
│       ├─ GVA repeat (if num_v_heads > num_k_heads)
│       └─ SSM dispatch → GDNKernelDispatcher
│
├─ GDNKernelDispatcher (strategy pattern)
│   ├─ decode()          → fused_recurrent_gated_delta_rule_update (Triton, in-place pool)
│   │                      input shape: [batch, 1, H, D] (batch-first, indexed via ssm_indices)
│   └─ extend()          → chunk_gated_delta_rule (Triton, clone + writeback)
│                          input shape: [1, total_tokens, H, D] (varlen, segmented by cu_seqlens)
│
└─ Imports
    └─ fused_gdn_gating — imported directly from the fla package

Active batch trimming

Problem: FD share_inputs pre-allocates full [max_num_seqs] buffers, but the GDN kernels need exactly [active_bs].
Solution:
  gpu_model_runner.py:  gdn_active_bs = sum(seq_lens_this_time > 0)
                        gdn_slot_ids / has_initial_state / seq_lens_cpu are trimmed to [:active_bs]
  gdn_backend.py:       cu_seqlens = forward_meta.cu_seqlens_q[:batch_size + 1]  (local slice)
Isolation: does not affect AppendAttention (zero overlap; cu_seqlens_q is a slice, the original tensor is untouched)
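The trimming above can be sketched with plain NumPy. This is an illustrative sketch of the described logic, not the FastDeploy code; the values are made up, and it assumes active requests are packed at the front of the batch (which is what makes a prefix slice sufficient):

```python
import numpy as np

# Pre-allocated [max_num_seqs] buffers; -1 is the PAD_SLOT_ID sentinel.
max_num_seqs = 6
seq_lens_this_time = np.array([1, 5, 1, 0, 0, 0], dtype=np.int32)
gdn_slot_ids = np.array([3, 7, 2, -1, -1, -1], dtype=np.int32)

# Count active requests, then trim the GDN-side buffers to [active_bs].
active_bs = int((seq_lens_this_time > 0).sum())
trimmed_slot_ids = gdn_slot_ids[:active_bs]
print(active_bs, trimmed_slot_ids.tolist())  # → 3 [3, 7, 2]
```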

Key design decisions

| Decision | Choice | Rationale |
| --- | --- | --- |
| SSM state precision | fp32 | SSM state accumulates across steps; high overflow risk in bf16 |
| Conv state precision | bf16 | matches the model hidden state; no accumulation |
| Pool layout (SSM) | K-first [H, K, V] | matches the native layout of the FLA Triton kernels, avoiding a transpose |
| Slot 0 sentinel | PAD_SLOT_ID=-1 → +1 offset → slot 0 (zero-filled) | uniform padding handling: the kernel reads zeros / writes are harmless |
| has_initial_state derivation | seq_lens_decoder > 0 | derives from an existing signal, avoiding an extra share_inputs field |
| Preemption policy | release the slot; re-allocate on reschedule | SSM state cannot be partially restored (unlike KV cache) and must be recomputed |
| Backend architecture | single-file GDNAttentionBackend + GDNKernelDispatcher | vs. SGLang's 5 files / 3 abstraction layers, FD's 1 file / 2 layers is simpler |
| Kernel routing | strategy pattern (method dispatch) | vs. SGLang's Enum + ABC inheritance, direct method routing is lighter |
| Backend injection | forward_meta.gdn_attn_backend | passed via ForwardMeta, consistent with the forward_meta.attn_backend pattern |
| Trampoline layer | GDNAttention(nn.Layer) in its own file | analogous to Attention; model layers call self.gdn_attn = GDNAttention(), matching FD convention |

Usage or Command

1. GDNStatePool initialization

```python
from fastdeploy.cache_manager.gdn_state_pool import GDNStatePool

pool = GDNStatePool(
    max_num_seqs=256,
    num_gdn_layers=24,        # number of GDN layers
    conv_dim=3072,            # TP-local conv dimension
    conv_kernel_size=4,       # causal conv width
    num_v_heads=8,            # TP-local value head count
    head_k_dim=128,
    head_v_dim=128,
)
# conv_pool: [24, 257, 3072, 3]  bf16
# ssm_pool:  [24, 257, 8, 128, 128]  fp32
```
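A back-of-the-envelope memory estimate for these two pools, using the shapes and dtypes from the configuration above (pool size is max_num_seqs + 1 because slot 0 is the zero-filled padding sentinel):

```python
# Rough memory footprint of the GDN state pools for the configuration above.
num_gdn_layers = 24
pool_size = 256 + 1            # max_num_seqs + sentinel slot 0
conv_dim, conv_k = 3072, 4     # conv state keeps K-1 = 3 columns
num_v_heads, head_k, head_v = 8, 128, 128

conv_elems = num_gdn_layers * pool_size * conv_dim * (conv_k - 1)
ssm_elems = num_gdn_layers * pool_size * num_v_heads * head_k * head_v

conv_bytes = conv_elems * 2    # bf16: 2 bytes/element
ssm_bytes = ssm_elems * 4      # fp32: 4 bytes/element
print(f"conv_pool ~ {conv_bytes / 2**20:.1f} MiB")   # ~108.4 MiB
print(f"ssm_pool  ~ {ssm_bytes / 2**30:.2f} GiB")    # ~3.01 GiB
```

The fp32 SSM pool dominates, which is why its precision choice (see the design-decision table) is the main memory knob.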

2. GDNSlotAllocator (scheduler side)

```python
from fastdeploy.cache_manager.gdn_state_pool import GDNSlotAllocator

allocator = GDNSlotAllocator(max_num_seqs=256)
slot_id = allocator.allocate()   # 1-based slot ID in [1, max_num_seqs]
allocator.free(slot_id)          # return the slot
allocator.free(0)                # slot 0 is the sentinel; silently ignored
```
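A minimal self-contained sketch of the allocator semantics described above (1-based IDs, exhaustion raises, free() silently ignores the slot-0 sentinel). This is an illustrative stand-in, not the FastDeploy implementation; the class name is hypothetical:

```python
class SlotAllocatorSketch:
    """Sketch of a 1-based free-list slot allocator with a PAD sentinel at 0."""

    def __init__(self, max_num_seqs: int):
        self.max_num_seqs = max_num_seqs
        # Free list holds IDs 1..max_num_seqs; slot 0 is reserved as the
        # zero-filled PAD sentinel and is never handed out.
        self._free = list(range(max_num_seqs, 0, -1))

    def allocate(self) -> int:
        if not self._free:
            raise RuntimeError(f"no free slots (max_num_seqs={self.max_num_seqs})")
        return self._free.pop()

    def free(self, slot_id: int) -> None:
        if slot_id == 0:  # sentinel slot: silently ignore
            return
        self._free.append(slot_id)

alloc = SlotAllocatorSketch(max_num_seqs=4)
sid = alloc.allocate()
alloc.free(sid)
alloc.free(0)  # no-op
```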

3. Model-layer usage of GDNAttention (integrated in a follow-up PR)

```python
# in qwen3_5.py GatedDeltaNet.__init__():
from fastdeploy.model_executor.layers.attention import GDNAttention
self.gdn_attn = GDNAttention(fd_config, layer_id)

# in qwen3_5.py GatedDeltaNet.forward():
core_attn_out = self.gdn_attn(mixed_qkv, a, b, self, forward_meta)
```

4. Prefill / Decode phase behavior

| Phase | has_initial_state | Backend behavior |
| --- | --- | --- |
| Prefill (new request) | False | Conv1d from zero → chunk SSM computed from zero → written to pool |
| Decode (continued generation) | True | Conv1d update reads/writes pool → fused recurrent reads/writes pool |
| Chunked Prefill (subsequent chunks) | True | reads the state written by the previous chunk |
| Reschedule after preemption | False | new slot, recomputed from zero (SSM state cannot be partially restored) |
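The has_initial_state column above is derived from an existing signal rather than a new field: a request has prior GDN state exactly when it has already decoded something. A tiny illustrative sketch (values are made up):

```python
import numpy as np

# Per-request decoded lengths; 0 means a fresh prefill with no prior state.
seq_lens_decoder = np.array([0, 12, 0, 3], dtype=np.int32)

# has_initial_state: fresh prefill → False, continued decode → True
has_initial_state = seq_lens_decoder > 0
print(has_initial_state.tolist())  # → [False, True, False, True]
```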

Running the tests

```shell
cd FastDeploy
# functional tests
python -m pytest tests/cache_manager/test_gdn_state_pool.py -v       # 28 tests
python -m pytest tests/model_executor/ops/triton_ops/test_gdn_backend.py -v  # 6 tests (numerical validation)

# regression tests
python -m pytest tests/engine/test_request.py tests/v1/test_resource_manager_v1.py tests/worker/test_gpu_model_runner.py tests/cache_manager/test_prefix_cache_manager.py tests/model_executor/test_forward_meta_str.py tests/worker/test_reorder_split_prefill_and_decode.py -v
```

Test Results

Functional tests (34 cases)

| Test class | Cases | Description |
| --- | --- | --- |
| TestGDNStatePool | 10 | pool creation, shapes, sentinel, layer views, reset |
| TestGDNStatePoolAllocateFree | 4 | pool-side allocate/free, state-zeroing checks |
| TestGDNSlotAllocator | 7 | CPU allocator: 1-based IDs, uniqueness, exhaustion error, free/recycle, sentinel ignored |
| TestPADSlotSentinel | 3 | PAD_SLOT_ID offset, sentinel zero values, harmless writes |
| TestMultiLayerConsistency | 2 | per-layer independence |
| TestMemoryFootprint | 2 | GPU memory footprint calculation |
| TestGDNBackendDecodeNumerical | 1 | decode: backend vs. reference pipeline numerical comparison (rtol/atol=1e-3) |
| TestGDNBackendExtendNumerical | 2 | extend: single-sequence + multi-sequence numerical comparison |
| TestGDNBackendMixedNumerical | 1 | mixed: 2 decode + 1 extend(6) numerical comparison |
| TestGDNBackendStateUpdate | 2 | Prefill→Decode state persistence + non-zero checks |

All 34/34 passed.

Regression tests

| Test file | Cases | Result | Description |
| --- | --- | --- | --- |
| tests/engine/test_request.py | 30 | ✅ passed | request.py changes introduce no regression |
| tests/v1/test_resource_manager_v1.py | 33 | ✅ 32 passed, 1 skipped | scheduler changes introduce no regression |
| tests/worker/test_gpu_model_runner.py | 19 | ✅ passed | worker + backend injection introduce no regression |
| tests/cache_manager/test_prefix_cache_manager.py | 56 | ✅ passed | no cross-impact on cache management |
| tests/model_executor/test_forward_meta_str.py | 1 | ✅ passed | new ForwardMeta fields are compatible |
| tests/worker/test_reorder_split_prefill_and_decode.py | 8 | ✅ passed | input_batch changes introduce no regression |

Total: 186 tests passed, 1 skipped, 0 failed.

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[Qwen3.5][Feature][KVCache]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code and run pre-commit before committing.
  • Add unit tests. Please state the reason in this PR if there are no unit tests.

Copilot AI review requested due to automatic review settings March 30, 2026 06:22

paddle-bot bot commented Mar 30, 2026

Thanks for your contribution!

Contributor

Copilot AI left a comment

Pull request overview

This PR adds an SSM/conv state cache pool and slot management for Qwen3.5's GDN (GatedDeltaNet) linear attention, and threads slot_id through the scheduler → worker → forward_meta chain, in preparation for kernels accessing/updating state by slot.

Changes:

  • Add GDNStatePool (GPU pre-allocated tensor pool) and GDNSlotAllocator (CPU slot allocator), introducing a PAD sentinel slot.
  • Scheduler/Worker/ForwardMeta add and pass through the gdn_slot_id(s) fields, used for allocation and release in Prefill/Preemption and similar scenarios.
  • Add unit tests tests/cache_manager/test_gdn_state_pool.py covering the core pool/allocator behavior.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 5 comments.

Show a summary per file

| File | Description |
| --- | --- |
| fastdeploy/cache_manager/gdn_state_pool.py | new GDN state pool and slot allocator implementation |
| fastdeploy/engine/sched/resource_manager_v1.py | scheduler-side init/allocation/release of GDN slots |
| fastdeploy/worker/input_batch.py | add the gdn_slot_ids buffer to share_inputs |
| fastdeploy/worker/gpu_model_runner.py | write gdn_slot_ids when inserting tasks; fill GDN fields on forward_meta; initialize the state pool |
| fastdeploy/model_executor/forward_meta.py | add GDN-related fields to ForwardMeta |
| fastdeploy/engine/request.py | add the gdn_slot_id attribute to Request |
| tests/cache_manager/test_gdn_state_pool.py | new GDNStatePool / GDNSlotAllocator unit tests |

self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
request.block_tables, dtype="int32"
)
# Write GDN slot ID for this request

Copilot AI Mar 30, 2026
When writing gdn_slot_ids, the value is only assigned when request.gdn_slot_id is not None; if a slot is ever unallocated or the field is missing (e.g. the splitwise prefill path or an error branch), that batch idx may retain the previous request's slot_id, causing state to be wrongly reused later. Suggestion: as with block_tables, have the prefill branch first reset that idx's gdn_slot_ids to PAD (-1), then overwrite it from request.gdn_slot_id.

Suggested change
# Write GDN slot ID for this request
# Write GDN slot ID for this request
# Reset current batch idx gdn_slot_ids to PAD (-1) to avoid stale values
self.share_inputs["gdn_slot_ids"][idx : idx + 1] = -1

Comment on lines +1367 to +1368
# Derive gdn_seq_lens_cpu from seq_lens_this_time
if self.forward_meta.seq_lens_this_time is not None:

Copilot AI Mar 30, 2026
Deriving gdn_seq_lens_cpu here via .numpy().tolist() on seq_lens_this_time (usually on the GPU) causes a synchronous D2H copy, triggered every step, which may hurt decode throughput. InputBatch already maintains seq_lens_this_time_cpu (copied over in prepare_inputs); prefer building the list from seq_lens_this_time_cpu to avoid the extra sync.

Suggested change
# Derive gdn_seq_lens_cpu from seq_lens_this_time
if self.forward_meta.seq_lens_this_time is not None:
# Derive gdn_seq_lens_cpu from CPU copy if available to avoid extra D2H sync
seq_lens_this_time_cpu = getattr(self.forward_meta, "seq_lens_this_time_cpu", None)
if seq_lens_this_time_cpu is not None:
self.forward_meta.gdn_seq_lens_cpu = seq_lens_this_time_cpu.tolist()
elif self.forward_meta.seq_lens_this_time is not None:
# Fallback: keep original behavior when CPU copy is not available

Comment on lines +219 to +229
@staticmethod
def offset_slot_ids(raw_slot_ids: paddle.Tensor) -> paddle.Tensor:
"""Apply +1 offset to raw slot IDs so PAD_SLOT_ID (-1) maps to slot 0.

Args:
raw_slot_ids: [batch_size] int32, may contain PAD_SLOT_ID (-1).

Returns:
Offset slot IDs where -1 -> 0, 0 -> 1, 1 -> 2, etc.
"""
return raw_slot_ids + 1

Copilot AI Mar 30, 2026
The slot_id semantics are inconsistent within this module: GDNSlotAllocator / GDNStatePool.allocate() both return 1-based IDs (slot 0 is sentinel-only), but offset_slot_ids() does a flat +1, shifting all valid slots and potentially producing an out-of-bounds index of max_num_seqs+1 (pool_size = max_num_seqs+1, so the maximum index is max_num_seqs). Suggestion: settle on one convention (1-based end-to-end is recommended), change offset_slot_ids to map only PAD_SLOT_ID (-1) to 0 while leaving other values unchanged, and update the "+1 offset" wording in the top-level comment accordingly.

Comment on lines +175 to +178
"""offset_slot_ids: -1->0, 0->1, 1->2, etc."""
raw = paddle.to_tensor([-1, 0, 1, 5], dtype=paddle.int32)
offset = GDNStatePool.offset_slot_ids(raw)
expected = paddle.to_tensor([0, 1, 2, 6], dtype=paddle.int32)

Copilot AI Mar 30, 2026
This offset_slot_ids unit test asserts "raw 0-based +1" semantics (-1→0, 0→1, 1→2, ...), while the rest of the file/PR (GDNSlotAllocator, GDNStatePool.allocate/free, Request.gdn_slot_id) uses 1-based slot_ids. After unifying the slot_id convention, update this case's inputs and expected values accordingly, so the test does not mask an off-by-one bug.

Suggested change
"""offset_slot_ids: -1->0, 0->1, 1->2, etc."""
raw = paddle.to_tensor([-1, 0, 1, 5], dtype=paddle.int32)
offset = GDNStatePool.offset_slot_ids(raw)
expected = paddle.to_tensor([0, 1, 2, 6], dtype=paddle.int32)
"""offset_slot_ids should pass through 1-based slot ids unchanged."""
raw = paddle.to_tensor([1, 2, 5], dtype=paddle.int32)
offset = GDNStatePool.offset_slot_ids(raw)
expected = paddle.to_tensor([1, 2, 5], dtype=paddle.int32)

Comment on lines +1027 to 1030
# Allocate GDN SSM state slot if model has GDN layers
if self.gdn_slot_allocator is not None and request.gdn_slot_id is None:
request.gdn_slot_id = self.gdn_slot_allocator.allocate()
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))

Copilot AI Mar 30, 2026
gdn_slot_id is currently only allocated on WAITING→RUNNING / PREEMPTED re-enqueue; but when splitwise_role == prefill, common_engine pushes requests directly into the running queue via add_request_in_p(), and the running-queue prefill branch never reaches this allocation logic, leaving request.gdn_slot_id as None — the worker-side gdn_slot_ids may then stay at PAD or hold a stale value, corrupting state. Suggestion: also allocate when a request first enters running (e.g. in add_request_in_p()/preallocate_resource_in_p, or in schedule()'s running-prefill branch).

@wanderHZ wanderHZ changed the title [Feature] support gdn kv cache manager for qwen3.5 [Qwen3.5][Feature][KVCache] support gdn kv cache manager and backend for qwen3.5 Mar 31, 2026
Copilot AI review requested due to automatic review settings March 31, 2026 05:17
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 6 comments.

Comment on lines +235 to +237
# GDN SSM state pool slot IDs (PAD_SLOT_ID=-1 for empty slots)
self.gdn_slot_ids = paddle.full([max_num_seqs], -1, dtype="int32")


Copilot AI Mar 31, 2026
share_inputs.gdn_slot_ids is added here, but neither InputBatch.swap_states() nor reset_share_inputs() handles the new field:

  • condense()/reorder calls swap_states(); if gdn_slot_ids is not swapped along with idx, the request ↔ slot_id mapping gets mis-aligned and the wrong GDN state is read/written.
  • reset_share_inputs() also does not restore gdn_slot_ids to -1, so a new request may reuse a stale slot_id.
    Suggestion: add gdn_slot_ids to the swap_data list in swap_states, and set it to -1 in reset_share_inputs() via fill_paddle_tensor.
Comment on lines +409 to +411
from fastdeploy.model_executor.ops.triton_ops.fla import (
fused_recurrent_gated_delta_rule_update,
)

Copilot AI Mar 31, 2026
GDNAttentionBackend depends on fastdeploy.model_executor.ops.triton_ops.fla here, but the current codebase has no fla package under fastdeploy/model_executor/ops/triton_ops/ (the import fails outright). If this PR depends on files from the previous PR (#7024), the required modules must be included in this PR or in the merge order, or the import should be changed to the correct actual path with an availability check (falling back when the module is absent).

Suggested change
from fastdeploy.model_executor.ops.triton_ops.fla import (
fused_recurrent_gated_delta_rule_update,
)
try:
from fastdeploy.model_executor.ops.triton_ops.fla import (
fused_recurrent_gated_delta_rule_update,
)
except ImportError as exc:
logger.error(
"Triton FLA backend `fastdeploy.model_executor.ops.triton_ops.fla` "
"is not available. Please ensure it is built and installed correctly."
)
raise RuntimeError(
"Missing Triton FLA backend for GDNKernelDispatcher.decode()."
) from exc

Comment on lines +575 to +590
from fastdeploy.model_executor.ops.triton_ops.causal_conv1d import (
causal_conv1d_update as triton_conv1d_update,
)

mixed_qkv = triton_conv1d_update(
x=mixed_qkv,
conv_state=conv_pool,
weight=conv_weight_local,
bias=None,
activation="silu",
conv_state_indices=slot_ids,
)
elif not is_decode and has_pool:
from fastdeploy.model_executor.ops.triton_ops.causal_conv1d import (
causal_conv1d_fn as triton_conv1d_fn,
)

Copilot AI Mar 31, 2026
This imports fastdeploy.model_executor.ops.triton_ops.causal_conv1d, but no such module exists under fastdeploy/model_executor/ops/triton_ops/ in the current codebase (it will raise ImportError at runtime). Suggestions:

  1. confirm the real location of causal_conv1d and fix the import; or
  2. if the module comes from a dependency PR, ensure the merge order / include the files in this PR, and fall back to _causal_conv1d_fn_fallback automatically on import failure.

Comment on lines +557 to +563
# Get pool views for this layer
gdn_pool = forward_meta.gdn_state_pool
raw_slot_ids = forward_meta.gdn_slot_ids

conv_pool = gdn_pool.get_layer_conv_pool(layer.gdn_layer_idx) if gdn_pool is not None else None
ssm_pool = gdn_pool.get_layer_ssm_pool(layer.gdn_layer_idx) if gdn_pool is not None else None


Copilot AI Mar 31, 2026
GDNAttentionBackend.forward relies on attributes on layer such as gdn_layer_idx / conv1d_weight / conv_dim / num_k_heads_local / head_k_dim / ... / A_log / dt_bias, but a global search of the current repo finds no definition or assignment of these attributes (outside this file), so a model forward would raise AttributeError immediately. A stable interface between the backend and the model layer should be made explicit:

  • either fill in these attributes on the corresponding Qwen3.5 GDN layer with consistent naming;
  • or read the real field names of the existing model layer here (with a clear error/fallback when they are missing).

Comment on lines +466 to +471
# Write back final states (updated in-place by the kernel)
batch_size = slot_ids.shape[0]
for i in range(batch_size):
sid = int(slot_ids[i])
if sid > 0: # skip padding sentinel (slot 0)
ssm_pool[sid] = initial_state[i]

Copilot AI Mar 31, 2026
The extend-path state writeback here uses a Python for-loop and an int(...) conversion of slot_ids[i]:

  • if slot_ids is on the GPU, per-element conversion to Python int triggers syncs/copies and may even error outright;
  • even on the CPU, the per-request writeback loop becomes a clear bottleneck at large batch sizes.
    Suggestion: switch to a vectorized writeback (e.g. first mask out slot 0, then write initial_state back to ssm_pool in a single advanced-indexing/scatter operation).
Suggested change
# Write back final states (updated in-place by the kernel)
batch_size = slot_ids.shape[0]
for i in range(batch_size):
sid = int(slot_ids[i])
if sid > 0: # skip padding sentinel (slot 0)
ssm_pool[sid] = initial_state[i]
# Write back final states (updated in-place by the kernel) using vectorized indexing
# Only slots with id > 0 (non-padding) are written back.
valid_mask = slot_ids > 0
if valid_mask.any():
valid_slot_ids = slot_ids[valid_mask]
ssm_pool[valid_slot_ids] = initial_state[valid_mask]

Comment on lines +1367 to +1369
# Derive gdn_seq_lens_cpu from seq_lens_this_time
if self.forward_meta.seq_lens_this_time is not None:
self.forward_meta.gdn_seq_lens_cpu = self.forward_meta.seq_lens_this_time.numpy().tolist()

Copilot AI Mar 31, 2026
gdn_seq_lens_cpu is regenerated every step via seq_lens_this_time.numpy().tolist(), adding an extra D2H sync/copy; upstream, _preprocess() already asynchronously copies seq_lens_this_time into the pinned seq_lens_this_time_cpu (see line 1164 of the same file). Prefer building the list directly from share_inputs["seq_lens_this_time_cpu"] (or produce the needed CPU lens once in the pre_process stage) to avoid repeated syncs degrading decode throughput.

Suggested change
# Derive gdn_seq_lens_cpu from seq_lens_this_time
if self.forward_meta.seq_lens_this_time is not None:
self.forward_meta.gdn_seq_lens_cpu = self.forward_meta.seq_lens_this_time.numpy().tolist()
# Derive gdn_seq_lens_cpu from CPU seq_lens if available to avoid extra D2H copy
if self.forward_meta.seq_lens_this_time is not None:
seq_lens_cpu = self.share_inputs.get("seq_lens_this_time_cpu")
if seq_lens_cpu is not None:
# seq_lens_this_time_cpu is prepared in _preprocess (pinned CPU tensor or list)
if isinstance(seq_lens_cpu, paddle.Tensor):
self.forward_meta.gdn_seq_lens_cpu = seq_lens_cpu.numpy().tolist()
else:
# Assume seq_lens_cpu is already on CPU (e.g. list/ndarray)
self.forward_meta.gdn_seq_lens_cpu = list(seq_lens_cpu)
else:
# Fallback: keep previous behavior and derive from seq_lens_this_time directly
self.forward_meta.gdn_seq_lens_cpu = self.forward_meta.seq_lens_this_time.numpy().tolist()

Copilot AI review requested due to automatic review settings March 31, 2026 13:30
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 11 out of 11 changed files in this pull request and generated 7 comments.

Comment on lines +54 to +69
def __init__(self, max_num_seqs: int):
self.max_num_seqs = max_num_seqs
self._free_slots: List[int] = list(range(max_num_seqs, 0, -1))

def allocate(self) -> int:
"""Allocate a single slot ID.

Returns:
Allocated slot ID (1-based).

Raises:
RuntimeError: If no free slots available.
"""
if not self._free_slots:
raise RuntimeError(f"GDNSlotAllocator: no free slots (max_num_seqs={self.max_num_seqs})")
return self._free_slots.pop()

Copilot AI Mar 31, 2026
The slot_id convention is inconsistent across scheduler/worker/tests and will cause out-of-bounds indexing at runtime: GDNSlotAllocator.allocate() returns 1..max_num_seqs, but offset_slot_ids does a flat +1, mapping the largest slot_id to max_num_seqs+1 (beyond the maximum index max_num_seqs of pool_size = max_num_seqs+1). Suggestion: unify the convention — either have the allocator return 0..max_num_seqs-1 and keep the +1 offset, or keep the allocator 1-based and change offset_slot_ids to only map PAD_SLOT_ID (-1) to 0 (e.g. clip to >= 0) instead of shifting everything by +1.

Comment on lines +220 to +229
def offset_slot_ids(raw_slot_ids: paddle.Tensor) -> paddle.Tensor:
"""Apply +1 offset to raw slot IDs so PAD_SLOT_ID (-1) maps to slot 0.

Args:
raw_slot_ids: [batch_size] int32, may contain PAD_SLOT_ID (-1).

Returns:
Offset slot IDs where -1 -> 0, 0 -> 1, 1 -> 2, etc.
"""
return raw_slot_ids + 1

Copilot AI Mar 31, 2026
offset_slot_ids is currently implemented as raw_slot_ids + 1, which does not match the gdn_slot_id written upstream (ResourceManagerV1/insert_tasks_v1 pass the allocator result through unchanged), so a legal slot can easily be mapped out of bounds. Suggestion: handle PAD_SLOT_ID only (-1 → 0, everything else unchanged), or alternatively make the allocator 0-based and update all call sites/tests accordingly. Also update the docstring and unit-test expectations to prevent later misuse.

Comment on lines +216 to +217
# Offset slot_ids: PAD_SLOT_ID (-1) → slot 0
slot_ids = GDNStatePool.offset_slot_ids(raw_slot_ids)

Copilot AI Mar 31, 2026
Offsetting raw_slot_ids here (PAD -1 → slot 0) relies on the offset_slot_ids convention; but the scheduler currently allocates 1-based slot_ids, and if offset_slot_ids still does a flat +1, the resulting slot_ids index conv_pool/ssm_pool out of bounds. Once the slot_id convention is fixed, sync the comment and logic here to guarantee slot_ids fall strictly within [0, pool_size-1].

Suggested change
# Offset slot_ids: PAD_SLOT_ID (-1) → slot 0
slot_ids = GDNStatePool.offset_slot_ids(raw_slot_ids)
# Normalize slot_ids: PAD_SLOT_ID (-1) should be mapped into a valid slot (typically 0)
slot_ids = GDNStatePool.offset_slot_ids(raw_slot_ids)
# Guard against out-of-range slot_ids to avoid invalid conv/ssm pool access
pool_size = conv_pool.shape[0]
invalid_mask = paddle.logical_or(slot_ids < 0, slot_ids >= pool_size)
if bool(paddle.any(invalid_mask).item()):
slot_min = int(slot_ids.min().item())
slot_max = int(slot_ids.max().item())
msg = (
f"Invalid GDN slot_ids detected: min={slot_min}, max={slot_max}, "
f"expected in [0, {pool_size - 1}]. Please check scheduler and "
"GDNStatePool.offset_slot_ids conventions."
)
logger.error(msg)
raise ValueError(msg)

Comment on lines +149 to +154
# Write back final states (updated in-place by the kernel)
batch_size = slot_ids.shape[0]
for i in range(batch_size):
sid = int(slot_ids[i])
if sid > 0: # skip padding sentinel (slot 0)
ssm_pool[sid] = initial_state[i]

Copilot AI Mar 31, 2026
The extend path writes back final state with a Python for-loop assigning one entry at a time (batch_size iterations); under high-concurrency / large-batch prefill and chunked prefill this becomes a clear CPU-side bottleneck and launches many tiny kernels/syncs. Suggestion: use a vectorized writeback — mask slot_ids > 0 and assign with a single indexed write (e.g. ssm_pool[slot_ids[mask]] = initial_state[mask]) or a scatter/update API, avoiding the per-element loop.

Suggested change
# Write back final states (updated in-place by the kernel)
batch_size = slot_ids.shape[0]
for i in range(batch_size):
sid = int(slot_ids[i])
if sid > 0: # skip padding sentinel (slot 0)
ssm_pool[sid] = initial_state[i]
# Write back final states (updated in-place by the kernel).
# Use vectorized writeback to avoid Python-side per-element loops.
mask = slot_ids > 0 # skip padding sentinel (slot 0)
ssm_pool[slot_ids[mask]] = initial_state[mask]


import numpy as np
import paddle


Copilot AI Mar 31, 2026
This test file depends on Triton CUDA kernels (causal_conv1d / FLA) but does not check CUDA availability and skip in setUp/setUpClass; on CPU-only or non-CUDA CI it fails outright. Suggestion: follow the repo's other GPU tests — in setUpClass, check paddle.is_compiled_with_cuda()/current_platform.is_cuda(), otherwise raise unittest.SkipTest, and call paddle.set_device('gpu') where needed.

Suggested change
# Skip this test module if CUDA is not available, since it depends on Triton CUDA kernels.
if not paddle.is_compiled_with_cuda():
raise unittest.SkipTest("CUDA is required to run GDNAttentionBackend Triton kernel tests.")
paddle.set_device("gpu")


Run:
cd FastDeploy
python -m pytest tests/model_executor/ops/triton_ops/test_gdn_backend.py -v

Copilot AI Mar 31, 2026
The run command at the top of the file does not match the file's actual path: the file is tests/model_executor/test_gdn_backend.py, but the comment says tests/model_executor/ops/triton_ops/test_gdn_backend.py, which can mislead local/CI runs. Suggestion: correct the path.

Suggested change
python -m pytest tests/model_executor/ops/triton_ops/test_gdn_backend.py -v
python -m pytest tests/model_executor/test_gdn_backend.py -v

Comment on lines +175 to +185
"""offset_slot_ids: -1->0, 0->1, 1->2, etc."""
raw = paddle.to_tensor([-1, 0, 1, 5], dtype=paddle.int32)
offset = GDNStatePool.offset_slot_ids(raw)
expected = paddle.to_tensor([0, 1, 2, 6], dtype=paddle.int32)
self.assertTrue(paddle.all(offset == expected).item())

def test_offset_slot_ids_pad_maps_to_sentinel(self):
"""PAD_SLOT_ID (-1) should map to slot 0 (the zero-filled sentinel)."""
raw = paddle.to_tensor([PAD_SLOT_ID], dtype=paddle.int32)
offset = GDNStatePool.offset_slot_ids(raw)
self.assertEqual(offset[0].item(), 0)

Copilot AI Mar 31, 2026
This unit test expects offset_slot_ids to shift everything by +1 (-1→0, 0→1, 1→2, ...), but in the current allocator / runtime chain gdn_slot_id is passed through to the worker 1-based, so the +1 goes out of bounds. After unifying the slot_id convention, update these assertions accordingly (e.g. verify only PAD -1→0 and that legal slots stay unchanged), so the unit test does not mask a production bug.

Suggested change
"""offset_slot_ids: -1->0, 0->1, 1->2, etc."""
raw = paddle.to_tensor([-1, 0, 1, 5], dtype=paddle.int32)
offset = GDNStatePool.offset_slot_ids(raw)
expected = paddle.to_tensor([0, 1, 2, 6], dtype=paddle.int32)
self.assertTrue(paddle.all(offset == expected).item())
def test_offset_slot_ids_pad_maps_to_sentinel(self):
"""PAD_SLOT_ID (-1) should map to slot 0 (the zero-filled sentinel)."""
raw = paddle.to_tensor([PAD_SLOT_ID], dtype=paddle.int32)
offset = GDNStatePool.offset_slot_ids(raw)
self.assertEqual(offset[0].item(), 0)
"""offset_slot_ids: non-PAD slot ids should remain unchanged."""
raw = paddle.to_tensor([0, 1, 5], dtype=paddle.int32)
offset = GDNStatePool.offset_slot_ids(raw)
self.assertTrue(paddle.all(offset == raw).item())
def test_offset_slot_ids_pad_maps_to_sentinel(self):
"""PAD_SLOT_ID (-1) should map to slot 0 (the zero-filled sentinel), legal slots unchanged."""
raw = paddle.to_tensor([PAD_SLOT_ID, 1, 5], dtype=paddle.int32)
offset = GDNStatePool.offset_slot_ids(raw)
expected = paddle.to_tensor([0, 1, 5], dtype=paddle.int32)
self.assertTrue(paddle.all(offset == expected).item())
