[Qwen3.5][Feature][KVCache] support gdn kv cache manager and backend for qwen3.5 #7074
wanderHZ wants to merge 6 commits into PaddlePaddle:develop from
Conversation
Thanks for your contribution!
Pull request overview
This PR adds a cache pool and slot management for the SSM/conv state of Qwen3.5's GDN (GatedDeltaNet) linear attention, and threads slot_id through the scheduler→worker→forward_meta chain, preparing for subsequent kernels to access/update state via slots.
Changes:
- Add GDNStatePool (a pre-allocated GPU tensor pool) and GDNSlotAllocator (a CPU slot allocator), and introduce a PAD sentinel slot design.
- Add and pass through gdn_slot_id(s) fields in Scheduler/Worker/ForwardMeta, used for allocation and release in scenarios such as Prefill/Preemption.
- Add the unit test tests/cache_manager/test_gdn_state_pool.py covering the core behavior of the pool/allocator.
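The slot allocator described above can be sketched as a simple free-list with a reserved PAD sentinel. Names and details here are illustrative, not the PR's exact implementation:

```python
# Hedged sketch of a free-list slot allocator with a PAD sentinel,
# mirroring the GDNSlotAllocator design described in this PR.
PAD_SLOT_ID = -1  # sentinel written into gdn_slot_ids for empty batch rows


class SlotAllocatorSketch:
    def __init__(self, max_num_seqs: int):
        self.max_num_seqs = max_num_seqs
        # 1-based ids: slot 0 is reserved as the zero-filled PAD sentinel,
        # so the free list holds [max_num_seqs, ..., 2, 1]
        self._free_slots = list(range(max_num_seqs, 0, -1))

    def allocate(self) -> int:
        if not self._free_slots:
            raise RuntimeError("no free GDN slots")
        return self._free_slots.pop()  # pops the smallest id first

    def free(self, slot_id: int) -> None:
        self._free_slots.append(slot_id)


alloc = SlotAllocatorSketch(max_num_seqs=4)
first = alloc.allocate()  # -> 1, given the reversed free-list initialization
```

Freed slots go back on the free list and are handed out again, which is the lifecycle the scheduler-side allocation/release in this PR has to keep consistent.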
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| fastdeploy/cache_manager/gdn_state_pool.py | Add GDN state pool and slot allocator implementation |
| fastdeploy/engine/sched/resource_manager_v1.py | Scheduler-side initialization/allocation/release of GDN slots |
| fastdeploy/worker/input_batch.py | Add a gdn_slot_ids buffer to share_inputs |
| fastdeploy/worker/gpu_model_runner.py | Write gdn_slot_ids when inserting tasks; fill GDN fields in forward_meta; initialize the state pool |
| fastdeploy/model_executor/forward_meta.py | Add GDN-related fields to ForwardMeta |
| fastdeploy/engine/request.py | Add a gdn_slot_id attribute to Request |
| tests/cache_manager/test_gdn_state_pool.py | Add GDNStatePool / GDNSlotAllocator unit tests |
```python
self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
    request.block_tables, dtype="int32"
)
# Write GDN slot ID for this request
```
When writing gdn_slot_ids, the value is only assigned when request.gdn_slot_id is not None; once a slot goes unallocated or the field is missing (e.g. the splitwise prefill path or an exception branch), that batch idx may retain the previous request's slot_id, causing later requests to incorrectly reuse state. Suggestion: as with block_tables, have the prefill branch first set gdn_slot_ids at that idx to PAD (-1), then overwrite it from request.gdn_slot_id.
```python
# Write GDN slot ID for this request
# Reset current batch idx gdn_slot_ids to PAD (-1) to avoid stale values
self.share_inputs["gdn_slot_ids"][idx : idx + 1] = -1
```
```python
# Derive gdn_seq_lens_cpu from seq_lens_this_time
if self.forward_meta.seq_lens_this_time is not None:
```
Deriving gdn_seq_lens_cpu via .numpy().tolist() on seq_lens_this_time (usually on GPU) causes a synchronous D2H copy, and it is triggered every step, which may hurt decode throughput. InputBatch already maintains seq_lens_this_time_cpu (copied over in prepare_inputs); prefer generating the list from seq_lens_this_time_cpu to avoid the extra sync.
```python
# Derive gdn_seq_lens_cpu from CPU copy if available to avoid extra D2H sync
seq_lens_this_time_cpu = getattr(self.forward_meta, "seq_lens_this_time_cpu", None)
if seq_lens_this_time_cpu is not None:
    self.forward_meta.gdn_seq_lens_cpu = seq_lens_this_time_cpu.tolist()
elif self.forward_meta.seq_lens_this_time is not None:
    # Fallback: keep original behavior when CPU copy is not available
```
```python
@staticmethod
def offset_slot_ids(raw_slot_ids: paddle.Tensor) -> paddle.Tensor:
    """Apply +1 offset to raw slot IDs so PAD_SLOT_ID (-1) maps to slot 0.

    Args:
        raw_slot_ids: [batch_size] int32, may contain PAD_SLOT_ID (-1).

    Returns:
        Offset slot IDs where -1 -> 0, 0 -> 1, 1 -> 2, etc.
    """
    return raw_slot_ids + 1
```
The slot_id semantics are inconsistent within the module: GDNSlotAllocator / GDNStatePool.allocate() both return 1-based ids (slot 0 is only the sentinel), but offset_slot_ids() applies a blanket +1, shifting all valid slots right and potentially producing an out-of-range index of max_num_seqs+1 (with pool_size=max_num_seqs+1, the maximum index is max_num_seqs). Suggestion: unify the convention (1-based end to end is recommended), change offset_slot_ids to only map PAD_SLOT_ID (-1) to 0 while leaving other values unchanged, and update the "+1 offset" wording in the docstring accordingly.
```python
"""offset_slot_ids: -1->0, 0->1, 1->2, etc."""
raw = paddle.to_tensor([-1, 0, 1, 5], dtype=paddle.int32)
offset = GDNStatePool.offset_slot_ids(raw)
expected = paddle.to_tensor([0, 1, 2, 6], dtype=paddle.int32)
```
This unit test asserts offset_slot_ids as "raw 0-based +1" (-1->0, 0->1, 1->2...), but the rest of this file/PR (GDNSlotAllocator, GDNStatePool.allocate/free, Request.gdn_slot_id) uses 1-based slot_id. Suggestion: after unifying the slot_id convention, update this case's inputs/expected values accordingly, so it does not mask off-by-one errors.
```python
"""offset_slot_ids should pass through 1-based slot ids unchanged."""
raw = paddle.to_tensor([1, 2, 5], dtype=paddle.int32)
offset = GDNStatePool.offset_slot_ids(raw)
expected = paddle.to_tensor([1, 2, 5], dtype=paddle.int32)
```
```python
# Allocate GDN SSM state slot if model has GDN layers
if self.gdn_slot_allocator is not None and request.gdn_slot_id is None:
    request.gdn_slot_id = self.gdn_slot_allocator.allocate()
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
```
Currently gdn_slot_id is only allocated on WAITING->RUNNING / PREEMPTED re-enqueue; but with splitwise_role==prefill, common_engine pushes requests directly into the running queue via add_request_in_p(), and the running-queue prefill branch never triggers this allocation logic, leaving request.gdn_slot_id as None; on the worker side gdn_slot_ids may then stay at PAD or hold a stale value and pollute state. Suggestion: also allocate when a request first enters running (e.g. in add_request_in_p()/preallocate_resource_in_p, or in schedule()'s running-prefill branch).
```python
# GDN SSM state pool slot IDs (PAD_SLOT_ID=-1 for empty slots)
self.gdn_slot_ids = paddle.full([max_num_seqs], -1, dtype="int32")
```
share_inputs.gdn_slot_ids is added here, but neither InputBatch.swap_states() nor reset_share_inputs() handles the new field:
- condense()/reorder calls swap_states(); if gdn_slot_ids is not swapped along with idx, the request-to-slot_id mapping gets misaligned, and the wrong GDN state is read/written.
- reset_share_inputs() also does not restore gdn_slot_ids to -1, so a new request may reuse an old slot_id.
Suggestion: add gdn_slot_ids to swap_states' swap_data list, and set it to -1 with fill_paddle_tensor in reset_share_inputs().
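The two invariants suggested above (swap slot ids together with the batch reorder; reset them to PAD) can be illustrated with a minimal pure-Python sketch. swap_states/reset_share_inputs here are stand-ins for the InputBatch methods, not the real implementations:

```python
PAD_SLOT_ID = -1
gdn_slot_ids = [3, 1, PAD_SLOT_ID, 2]  # one entry per batch row


def swap_states(buf, i, j):
    # swap both entries so the request->slot mapping follows the reorder
    buf[i], buf[j] = buf[j], buf[i]


def reset_share_inputs(buf):
    # restore every entry to PAD so new requests never see a stale slot id
    for k in range(len(buf)):
        buf[k] = PAD_SLOT_ID


swap_states(gdn_slot_ids, 0, 2)   # rows 0 and 2 traded places: [-1, 1, 3, 2]
reset_share_inputs(gdn_slot_ids)  # all rows back to PAD: [-1, -1, -1, -1]
```

If only the other buffers are swapped and gdn_slot_ids is left alone, row 0 would still point at slot 3 after the reorder, which is exactly the state-corruption scenario the comment warns about.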
```python
from fastdeploy.model_executor.ops.triton_ops.fla import (
    fused_recurrent_gated_delta_rule_update,
)
```
GDNAttentionBackend depends on fastdeploy.model_executor.ops.triton_ops.fla here, but the current codebase has no fla package under fastdeploy/model_executor/ops/triton_ops/ (the import fails outright). If this PR depends on files from the previous PR (#7024), the required modules need to be included in this PR / the merge order, or the import should be changed to the correct actual path with an availability check (falling back when absent).
```python
try:
    from fastdeploy.model_executor.ops.triton_ops.fla import (
        fused_recurrent_gated_delta_rule_update,
    )
except ImportError as exc:
    logger.error(
        "Triton FLA backend `fastdeploy.model_executor.ops.triton_ops.fla` "
        "is not available. Please ensure it is built and installed correctly."
    )
    raise RuntimeError(
        "Missing Triton FLA backend for GDNKernelDispatcher.decode()."
    ) from exc
```
```python
    from fastdeploy.model_executor.ops.triton_ops.causal_conv1d import (
        causal_conv1d_update as triton_conv1d_update,
    )

    mixed_qkv = triton_conv1d_update(
        x=mixed_qkv,
        conv_state=conv_pool,
        weight=conv_weight_local,
        bias=None,
        activation="silu",
        conv_state_indices=slot_ids,
    )
elif not is_decode and has_pool:
    from fastdeploy.model_executor.ops.triton_ops.causal_conv1d import (
        causal_conv1d_fn as triton_conv1d_fn,
    )
```
This imports fastdeploy.model_executor.ops.triton_ops.causal_conv1d, but the current codebase has no such module under fastdeploy/model_executor/ops/triton_ops/ (it will raise ImportError at runtime). Suggestions:
- confirm the real location of causal_conv1d and fix the import; or
- if the module comes from a dependency PR, ensure the merge order / include the files in this PR, and fall back to _causal_conv1d_fn_fallback automatically when the import fails.
```python
# Get pool views for this layer
gdn_pool = forward_meta.gdn_state_pool
raw_slot_ids = forward_meta.gdn_slot_ids

conv_pool = gdn_pool.get_layer_conv_pool(layer.gdn_layer_idx) if gdn_pool is not None else None
ssm_pool = gdn_pool.get_layer_ssm_pool(layer.gdn_layer_idx) if gdn_pool is not None else None
```
GDNAttentionBackend.forward depends on attributes on layer such as gdn_layer_idx / conv1d_weight / conv_dim / num_k_heads_local / head_k_dim / ... / A_log / dt_bias. A global search of the current repo finds no definitions/assignments of these attributes (outside this file), so the model forward would raise AttributeError outright. Suggestion: define a stable interface between the backend and the model layer:
- either add these attributes to the corresponding Qwen3.5 GDN layer with consistent naming;
- or read the actual field names of the existing model layer here (with a clear error/fallback when missing).
```python
# Write back final states (updated in-place by the kernel)
batch_size = slot_ids.shape[0]
for i in range(batch_size):
    sid = int(slot_ids[i])
    if sid > 0:  # skip padding sentinel (slot 0)
        ssm_pool[sid] = initial_state[i]
```
The extend-path state writeback uses a Python for-loop with an int(...) conversion on slot_ids[i]:
- if slot_ids is on GPU, converting elements to Python ints one by one triggers syncs/copies, and may even fail outright;
- even on CPU, a per-request writeback loop becomes a clear bottleneck at large batch sizes.
Suggestion: vectorize the writeback (e.g. mask out slot 0 first, then write initial_state back to ssm_pool in one shot with advanced indexing/scatter).
```python
# Write back final states (updated in-place by the kernel) using vectorized indexing
# Only slots with id > 0 (non-padding) are written back.
valid_mask = slot_ids > 0
if valid_mask.any():
    valid_slot_ids = slot_ids[valid_mask]
    ssm_pool[valid_slot_ids] = initial_state[valid_mask]
```
```python
# Derive gdn_seq_lens_cpu from seq_lens_this_time
if self.forward_meta.seq_lens_this_time is not None:
    self.forward_meta.gdn_seq_lens_cpu = self.forward_meta.seq_lens_this_time.numpy().tolist()
```
Here every step generates gdn_seq_lens_cpu via seq_lens_this_time.numpy().tolist(), which triggers an extra D2H sync/copy; upstream _preprocess() already copies seq_lens_this_time asynchronously into the pinned seq_lens_this_time_cpu (see line 1164 of the same file). Suggestion: generate the list directly from share_inputs["seq_lens_this_time_cpu"] (or produce the needed CPU lens once during pre_process), avoiding repeated syncs that degrade decode throughput.
```python
# Derive gdn_seq_lens_cpu from CPU seq_lens if available to avoid extra D2H copy
if self.forward_meta.seq_lens_this_time is not None:
    seq_lens_cpu = self.share_inputs.get("seq_lens_this_time_cpu")
    if seq_lens_cpu is not None:
        # seq_lens_this_time_cpu is prepared in _preprocess (pinned CPU tensor or list)
        if isinstance(seq_lens_cpu, paddle.Tensor):
            self.forward_meta.gdn_seq_lens_cpu = seq_lens_cpu.numpy().tolist()
        else:
            # Assume seq_lens_cpu is already on CPU (e.g. list/ndarray)
            self.forward_meta.gdn_seq_lens_cpu = list(seq_lens_cpu)
    else:
        # Fallback: keep previous behavior and derive from seq_lens_this_time directly
        self.forward_meta.gdn_seq_lens_cpu = self.forward_meta.seq_lens_this_time.numpy().tolist()
```
```python
def __init__(self, max_num_seqs: int):
    self.max_num_seqs = max_num_seqs
    self._free_slots: List[int] = list(range(max_num_seqs, 0, -1))

def allocate(self) -> int:
    """Allocate a single slot ID.

    Returns:
        Allocated slot ID (1-based).

    Raises:
        RuntimeError: If no free slots available.
    """
    if not self._free_slots:
        raise RuntimeError(f"GDNSlotAllocator: no free slots (max_num_seqs={self.max_num_seqs})")
    return self._free_slots.pop()
```
The slot_id convention is currently inconsistent across scheduler/worker/test, which leads to out-of-range indexing at runtime: GDNSlotAllocator.allocate() returns 1..max_num_seqs, but offset_slot_ids applies a blanket +1, mapping the largest slot_id to max_num_seqs+1 (beyond the maximum index max_num_seqs of pool_size=max_num_seqs+1). Suggestion: unify the convention: either have the allocator return 0..max_num_seqs-1 and keep the +1 offset; or keep the allocator 1-based but change offset_slot_ids to only map PAD_SLOT_ID (-1) to 0 (e.g. clip to >=0), without a blanket +1.
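The off-by-one described above can be checked with a few lines of arithmetic (illustrative numbers, not tied to any real config):

```python
# A 1-based allocator hands out ids in [1, max_num_seqs]; the pool has
# max_num_seqs + 1 rows (index 0 is the PAD sentinel), so a blanket +1
# offset pushes the largest legal id out of range.
max_num_seqs = 8
pool_size = max_num_seqs + 1          # valid indices: 0 .. max_num_seqs
largest_allocated = max_num_seqs      # the 1-based allocator's maximum id
assert largest_allocated + 1 == pool_size  # blanket +1 -> index 9, out of range

# Mapping only PAD (-1) to 0 (clip to >= 0) keeps every id in bounds:
normalized = [max(s, 0) for s in (-1, 1, largest_allocated)]
assert all(0 <= s < pool_size for s in normalized)
```

This is why the review recommends the clip-style normalization over the unconditional +1.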
```python
def offset_slot_ids(raw_slot_ids: paddle.Tensor) -> paddle.Tensor:
    """Apply +1 offset to raw slot IDs so PAD_SLOT_ID (-1) maps to slot 0.

    Args:
        raw_slot_ids: [batch_size] int32, may contain PAD_SLOT_ID (-1).

    Returns:
        Offset slot IDs where -1 -> 0, 0 -> 1, 1 -> 2, etc.
    """
    return raw_slot_ids + 1
```
offset_slot_ids is currently implemented as raw_slot_ids + 1, which mismatches the gdn_slot_id written upstream (ResourceManagerV1/insert_tasks_v1 pass the allocator's result through directly), and can map legal slots to out-of-range positions. Suggestion: handle only PAD_SLOT_ID here: -1→0, everything else unchanged (or correspondingly make the allocator 0-based and update all callers/tests). Also update the docstring and unit test expectations to prevent future misuse.
```python
# Offset slot_ids: PAD_SLOT_ID (-1) → slot 0
slot_ids = GDNStatePool.offset_slot_ids(raw_slot_ids)
```
The offset on raw_slot_ids here (PAD -1→slot 0) relies on the offset_slot_ids convention; but the scheduler side currently allocates 1-based slot_ids, and if offset_slot_ids remains a blanket +1, slot_ids will index conv_pool/ssm_pool out of range. Suggestion: after fixing the slot_id convention, sync the comment and logic here, ensuring slot_ids fall strictly within [0, pool_size-1].
```python
# Normalize slot_ids: PAD_SLOT_ID (-1) should be mapped into a valid slot (typically 0)
slot_ids = GDNStatePool.offset_slot_ids(raw_slot_ids)
# Guard against out-of-range slot_ids to avoid invalid conv/ssm pool access
pool_size = conv_pool.shape[0]
invalid_mask = paddle.logical_or(slot_ids < 0, slot_ids >= pool_size)
if bool(paddle.any(invalid_mask).item()):
    slot_min = int(slot_ids.min().item())
    slot_max = int(slot_ids.max().item())
    msg = (
        f"Invalid GDN slot_ids detected: min={slot_min}, max={slot_max}, "
        f"expected in [0, {pool_size - 1}]. Please check scheduler and "
        "GDNStatePool.offset_slot_ids conventions."
    )
    logger.error(msg)
    raise ValueError(msg)
```
```python
# Write back final states (updated in-place by the kernel)
batch_size = slot_ids.shape[0]
for i in range(batch_size):
    sid = int(slot_ids[i])
    if sid > 0:  # skip padding sentinel (slot 0)
        ssm_pool[sid] = initial_state[i]
```
The extend-path final-state writeback uses a Python for loop with per-element assignment (batch_size iterations); under high-concurrency / large-batch prefill or chunked prefill this becomes a clear CPU-side bottleneck and introduces many small kernels/syncs. Suggestion: use a vectorized writeback: mask slot_ids > 0, then assign in one indexed write (e.g. ssm_pool[slot_ids[mask]] = initial_state[mask]) or use a scatter/update API, avoiding per-element loops.
```python
# Write back final states (updated in-place by the kernel).
# Use vectorized writeback to avoid Python-side per-element loops.
mask = slot_ids > 0  # skip padding sentinel (slot 0)
ssm_pool[slot_ids[mask]] = initial_state[mask]
```
```python
import numpy as np
import paddle
```
This test file depends on Triton CUDA kernels (causal_conv1d / FLA) but does not check CUDA availability in setUp/setUpClass and skip accordingly; on CPU-only or non-CUDA CI it fails outright. Suggestion: follow the repo's other GPU tests: detect paddle.is_compiled_with_cuda()/current_platform.is_cuda() in setUpClass, otherwise raise unittest.SkipTest, and call paddle.set_device('gpu') when needed.
```python
# Skip this test module if CUDA is not available, since it depends on Triton CUDA kernels.
if not paddle.is_compiled_with_cuda():
    raise unittest.SkipTest("CUDA is required to run GDNAttentionBackend Triton kernel tests.")
paddle.set_device("gpu")
```
```
Run:
    cd FastDeploy
    python -m pytest tests/model_executor/ops/triton_ops/test_gdn_backend.py -v
```
The run command path in the file header does not match the actual file path: the file is tests/model_executor/test_gdn_backend.py, but the comment says tests/model_executor/ops/triton_ops/test_gdn_backend.py, which misleads local/CI runs. Suggestion: correct the path.
```
python -m pytest tests/model_executor/test_gdn_backend.py -v
```
```python
"""offset_slot_ids: -1->0, 0->1, 1->2, etc."""
raw = paddle.to_tensor([-1, 0, 1, 5], dtype=paddle.int32)
offset = GDNStatePool.offset_slot_ids(raw)
expected = paddle.to_tensor([0, 1, 2, 6], dtype=paddle.int32)
self.assertTrue(paddle.all(offset == expected).item())

def test_offset_slot_ids_pad_maps_to_sentinel(self):
    """PAD_SLOT_ID (-1) should map to slot 0 (the zero-filled sentinel)."""
    raw = paddle.to_tensor([PAD_SLOT_ID], dtype=paddle.int32)
    offset = GDNStatePool.offset_slot_ids(raw)
    self.assertEqual(offset[0].item(), 0)
```
This unit test expects offset_slot_ids to apply a blanket +1 (-1→0, 0→1, 1→2...), but in the current allocator / runtime chain gdn_slot_id reaches the worker 1-based, so the +1 produces out-of-range ids. Suggestion: after unifying the slot_id convention, update the assertions here accordingly (e.g. only verify PAD -1→0 and that legal slots stay unchanged), so the unit test does not mask a production bug.
```python
"""offset_slot_ids: non-PAD slot ids should remain unchanged."""
raw = paddle.to_tensor([0, 1, 5], dtype=paddle.int32)
offset = GDNStatePool.offset_slot_ids(raw)
self.assertTrue(paddle.all(offset == raw).item())

def test_offset_slot_ids_pad_maps_to_sentinel(self):
    """PAD_SLOT_ID (-1) should map to slot 0 (the zero-filled sentinel), legal slots unchanged."""
    raw = paddle.to_tensor([PAD_SLOT_ID, 1, 5], dtype=paddle.int32)
    offset = GDNStatePool.offset_slot_ids(raw)
    expected = paddle.to_tensor([0, 1, 5], dtype=paddle.int32)
    self.assertTrue(paddle.all(offset == expected).item())
```
Motivation
To support full Prefill + Decode inference for Qwen3.5's GatedDeltaNet (GDN) linear attention, FastDeploy needs complete cache management for the GDN SSM state and conv state, plus a dedicated GDN Attention Backend.
Unlike the paged KV cache of standard softmax attention, the state of GDN linear attention is a fixed-size dense matrix (SSM state [H, K, V] + conv state [dim, K-1]) that does not grow with the sequence. We therefore adopt a pre-allocated tensor pool + slot allocator scheme, following the Mamba/SSM state management approach in vLLM v1. Building on the previous stage (Triton kernel integration), this PR completes two parts of the work:
- GDNAttentionBackend (inheriting FDAttentionBackend), which internally routes Triton / Paddle kernels via the GDNKernelDispatcher strategy pattern and provides the model layer with a unified forward() interface.

Modifications
New files
- fastdeploy/cache_manager/gdn_state_pool.py
- fastdeploy/model_executor/layers/attention/gdn_backend.py — fused_gdn_gating imported directly from the fla package (Triton-only, ~337 lines)
- fastdeploy/model_executor/layers/attention/gdn_attention.py — inherits Attention; model layers use it via self.gdn_attn = GDNAttention()
- tests/cache_manager/test_gdn_state_pool.py
- tests/model_executor/ops/triton_ops/test_gdn_backend.py

Modified files
- fastdeploy/model_executor/forward_meta.py — gdn_state_pool, gdn_slot_ids, gdn_has_initial_state, gdn_seq_lens_cpu, gdn_attn_backend
- fastdeploy/engine/request.py — request.gdn_slot_id = None
- fastdeploy/worker/input_batch.py — gdn_slot_ids tensor, [max_num_seqs] int32
- fastdeploy/engine/sched/resource_manager_v1.py — _init_gdn_slot_allocator() + slot allocation at 2 places in schedule() + slot release in _free_blocks()
- fastdeploy/worker/gpu_model_runner.py — insert_tasks_v1() writes gdn_slot_ids + initialize_forward_meta() fills GDN fields (including active-batch trimming) + _initialize_gdn_state_pool() creates the pool + backend
- fastdeploy/model_executor/layers/attention/__init__.py — import and __all__ export of GDNAttention

Architecture design
Overall architecture
Slot lifecycle
GDNAttentionBackend internal structure
Active batch trimming
Key design decisions
- SSM state shape: [H, K, V]
- has_initial_state is derived from seq_lens_decoder > 0
- forward_meta.gdn_attn_backend follows the same pattern as forward_meta.attn_backend
- GDNAttention(nn.Layer) lives in a separate file and inherits Attention; model layers call it via self.gdn_attn = GDNAttention(), aligned with FD conventions

Usage or Command
1. GDNStatePool initialization
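As a rough sense of what pre-allocating the pool costs, here is a back-of-envelope sizing sketch; the helper name and all shapes below are assumptions for illustration, not Qwen3.5's actual configuration:

```python
# Hypothetical sizing helper for a fixed-size GDN state pool:
# per-slot SSM state [H, K, V] plus conv state [conv_dim, K_conv - 1],
# replicated across GDN layers, with one extra PAD sentinel slot.
def gdn_pool_bytes(max_num_seqs, num_layers, H, K, V, conv_dim, K_conv,
                   dtype_bytes=2):
    pool_size = max_num_seqs + 1  # +1 for the zero-filled PAD sentinel slot
    ssm = pool_size * num_layers * H * K * V * dtype_bytes
    conv = pool_size * num_layers * conv_dim * (K_conv - 1) * dtype_bytes
    return ssm + conv


# e.g. 64 seqs, 12 GDN layers, H=16, K=V=128, conv_dim=4096, K_conv=4, bf16
total = gdn_pool_bytes(64, 12, 16, 128, 128, 4096, 4)  # roughly 0.43 GB
```

The key property is that this footprint is fixed at startup and independent of sequence length, which is what makes the slot-based pool viable where a paged KV cache would keep growing.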
2. GDNSlotAllocator (scheduler side)
3. Model layers using GDNAttention (integrated in the next-stage PR)
4. Prefill / Decode phase behavior
has_initial_state by phase: False, True, True, False (per-phase table flattened during extraction)

Running the tests
Test Results
Functional tests (34 cases)
- TestGDNStatePool
- TestGDNStatePoolAllocateFree
- TestGDNSlotAllocator
- TestPADSlotSentinel
- TestMultiLayerConsistency
- TestMemoryFootprint
- TestGDNBackendDecodeNumerical
- TestGDNBackendExtendNumerical
- TestGDNBackendMixedNumerical
- TestGDNBackendStateUpdate
All 34/34 passed.
Regression tests
- tests/engine/test_request.py
- tests/v1/test_resource_manager_v1.py
- tests/worker/test_gpu_model_runner.py
- tests/cache_manager/test_prefix_cache_manager.py
- tests/model_executor/test_forward_meta_str.py
- tests/worker/test_reorder_split_prefill_and_decode.py
Total: 186 tests passed, 1 skipped, 0 failed
Checklist
- PR title carries the [Qwen3.5][Feature][KVCache] tags
- Ran pre-commit before commit