Skip to content

[Feature][KVCache] Implement Cache Manager V1 with multi-GPU support and MTP KV Cache#7097

Open
kevincheng2 wants to merge 16 commits intoPaddlePaddle:developfrom
kevincheng2:feature/cache_manager_v1_rebase
Open

[Feature][KVCache] Implement Cache Manager V1 with multi-GPU support and MTP KV Cache#7097
kevincheng2 wants to merge 16 commits intoPaddlePaddle:developfrom
kevincheng2:feature/cache_manager_v1_rebase

Conversation

@kevincheng2
Copy link
Copy Markdown
Collaborator

@kevincheng2 kevincheng2 commented Mar 31, 2026

Motivation

实现全新的 Cache Manager V1 三级缓存管理架构(fastdeploy/cache_manager/v1/),规范 Cache 模块组织结构,通过全异步换入换出与逐层换入计算传输重叠优化首 token 时延,同时提高 Cache 模块的扩展性和可维护性。

Key Features

全异步换入换出(Fully Async Transfer)

换入(H2D)和换出(D2H)均为全异步操作,基于独立 CuPy CUDA stream 实现,不阻塞推理主流程:

  • _output_stream(D2H 换出):fire-and-forget 模式,换出提交后立即返回,后台异步完成
  • _input_stream(H2D 换入):逐层异步提交,每层换入完成后通过 CUDA Event 通知前向计算线程

双流并行(Dual-Stream Parallelism)

换入流(_input_stream)与换出流(_output_stream)相互独立,无需互相等待,换入与换出可同时进行,进一步提升 NVLink/PCIe 带宽利用率。

逐层换入 + 计算传输重叠(Layer-by-layer H2D with Compute-Transfer Overlap)

H2D 换入采用逐层(layer-by-layer)模式:每完成一层的 H2D kernel 提交,即在 _input_stream 上记录 CUDA Event,通知前向线程可以开始使用该层数据,从而实现换入传输与 GPU 前向计算的流水线重叠,隐藏换入延迟。

D2H 换出采用全层一次性(all-layers at once)模式,在 _output_stream 上 fire-and-forget,不影响换入与推理。

CUDA Event 精确同步(No Busy-Polling)

LayerDoneCounter 基于 CUDA Event 实现逐层完成状态跟踪,前向线程通过 wait_for_layer(layer_idx) 同步具体层的 CUDA Event,避免 busy-polling,CPU 开销极低。

Modifications

新增核心模块(fastdeploy/cache_manager/v1/)

模块 说明
base.py KVCacheBase 抽象基类,统一初始化 model_configcache_configquant_configparallel_config
block_pool.py BlockPool,管理 Device/Host block 的分配与释放,支持 LRU 驱逐
cache_controller.py CacheController,三级缓存 block 状态机(DEVICE / HOST / SWAP_TO_HOST / SWAP_TO_DEVICE / LOADING_FROM_STORAGE / DELETING),协调 swap/backup/storage 操作
cache_manager.py CacheManager,对外统一接口,封装 prefix cache 匹配、block 分配/释放、swap 调度
cache_utils.py get_block_hash_extra_keys:计算包含多模态 token 的 block hash extra keys;LayerDoneCounter:基于 CUDA Event 追踪逐层传输完成状态,支持 Compute-Transfer Overlap
metadata.py 完整元数据定义:BlockNodeMatchResultCacheSwapMetadataPDTransferMetadataStorageMetadataTransferResult
radix_tree.py RadixTree,基于 Radix Tree 的 prefix cache 索引,支持多种 host 备份策略
transfer_manager.py CacheTransferManager,管理 Host↔Device 异步 swap,基于独立 CuPy stream 实现 Compute-Transfer Overlap

CUDA 算子更新(custom_ops/gpu_ops/)

  • swap_cache_optimized.cu:单层同步/异步 swap 算子,以及多层 swap 算子;连续 block 使用 cudaMemcpyAsync 快速路径,非连续 block 使用 warp-level PTX non-temporal 访存
  • 移除 swap_cache_all_layers_batch,简化为统一的 swap_cache_all_layers

BatchRequest 封装(fastdeploy/engine/request.py)

新增 BatchRequest 类,将原来的 List[Request] 升级为携带 Cache 调度元数据的批次对象:

  • requests:当前 batch 的所有 Request
  • cache_swap_metadata:batch 级别的合并换入元数据(CacheSwapMetadata,Host→Device),由各 Request 的 swap metadata 合并而来,调度器只需下发一次换入指令
  • cache_evict_metadata:batch 级别的合并换出元数据(CacheSwapMetadata,Device→Host)

BatchRequest 同时支持序列化(__getstate__/__setstate__),可跨进程传递(Worker 侧反序列化后直接使用)。

MTP KV Cache 支持

支持 cache_manager_v1 下的 MTP(Multi-Token Prediction)KV Cache 初始化,以及多模态 token 的 block hash 计算。

测试覆盖

新增/扩展单元测试:test_cache_utils.pytest_cache_controller.pytest_cache_manager.pytest_radix_tree.pytest_transfer_manager.pytest_request.py

本 PR 已验证以下场景的端到端功能:

场景 状态
GPU Cache(Device only) ✅ 已测试
GPU + CPU Cache(Device + Host swap) ✅ 已测试
CUDAGraph ✅ 已测试
MTP(Multi-Token Prediction)KV Cache ✅ 已测试
多模态(VLM)block hash ✅ 已测试
Tensor Parallel(TP) ✅ 已测试
Storage(三级缓存持久化,fastdeploy/cache_manager/v1/storage/ ⏳ 未测试,后续 PR 支持
PD 分离传输(fastdeploy/cache_manager/v1/transfer/ ⏳ 未测试,后续 PR 支持

注意fastdeploy/cache_manager/v1/storage/(三级缓存 Storage 场景)和 fastdeploy/cache_manager/v1/transfer/(PD 分离传输场景)相关功能已完成接口定义,但本 PR 尚未测试,实现均以 TODO 占位。完整的 Storage 和 PD 传输功能将在后续独立 PR 中完成并验证。

Usage or Command

开启 Cache Manager V1

通过环境变量 ENABLE_V1_KVCACHE_MANAGER 启用新架构(默认关闭,0):

export ENABLE_V1_KVCACHE_MANAGER=1

重新编译算子

本 PR 修改了 .cu 算子文件,必须重新编译算子

# 详情参考:https://github.com/PaddlePaddle/FastDeploy/blob/develop/docs/get_started/installation/nvidia_gpu.md#4-build-wheel-from-source
bash build.sh 1 python false [90] 1

启动服务

GPU cache 配置与之前完全一致,无需额外修改。如需开启 CPU host cache,与之前一样通过 --swap-space 指定(单位 GiB):

ENABLE_V1_KVCACHE_MANAGER=1 python -m fastdeploy.entrypoints.openai.api_server \
    --model /path/to/model \
    --swap-space 16

新增参数

参数 默认值 可选值 说明
--write-policy write_through_selective write_throughwrite_through_selectivewrite_back Host 备份策略。write_through:每次 block 完成后立即备份到 host;write_through_selective:block 命中次数达到阈值后才备份(节省 host 带宽);write_back:block 被驱逐时才写回 host
--write-through-threshold 2 正整数 write_through_selective 策略下触发备份所需的最小命中次数,仅在 --write-policy write_through_selective 时生效

示例:命中 3 次后才备份到 host:

ENABLE_V1_KVCACHE_MANAGER=1 python -m fastdeploy.entrypoints.openai.api_server \
    --model /path/to/model \
    --swap-space 16 \
    --write-policy write_through_selective \
    --write-through-threshold 3

Checklist

  • Add at least a tag in the PR title.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

kevincheng2 and others added 11 commits March 30, 2026 21:04
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When a node transitions from SWAP_TO_DEVICE to DEVICE via
complete_swap_to_device, it was not being added to the
_evictable_device set. This caused nodes with ref_count=0 to
become "orphaned" - not appearing in any evictable set despite
having cache_status=DEVICE.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add new cache_manager.py with cache management functionality
- Add radix_tree.py for prefix caching
- Update block_pool.py and metadata.py
- Update request.py and resource_manager_v1.py for scheduling
- Update gpu_model_runner.py for GPU model execution

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add CacheController class for cache management
- Update config.py with cache related configurations
- Refactor gpu_model_runner.py for improved cache handling
…Meta

## Motivation

修复 swap_cache_optimized.cu 中 H2D 方向时 src/dst block_ids 使用错误的问题,
并清理 ForwardMeta 中已废弃的 cache_controller 字段。

## Modifications

- fix: swap_cache_optimized.cu 中根据 D2H 模板参数正确选取 src/dst block_ids,
  修复 H2D 方向 src/dst 倒置 bug(同时修复 SwapCachePerLayerImpl 和 SwapCacheAllLayersBatchImpl)
- refactor: cache_manager/v1/__init__.py 将 LayerSwapTimeoutError 导入从
  cache_controller 改为 cache_utils(正确来源)
- refactor: ForwardMeta 移除废弃的 cache_controller 字段
- refactor: gpu_model_runner.py 移除对应的 cache_controller 赋值语句
- test: 新增 tests/cache_manager/v1/test_swap_cache_ops.py 单元测试

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
## Motivation

对 cache manager v1 进行重构和优化,精简代码结构,提升可维护性。

## Modifications

- 重构 transfer_manager.py,大幅精简代码逻辑
- 优化 swap_cache_optimized.cu GPU 算子实现
- 调整 cache_manager.py、cache_controller.py 逻辑,修复 free_device_blocks 方法缺失问题
- 更新 block_pool.py、cache_utils.py、metadata.py、radix_tree.py
- 精简 gpu_model_runner.py、forward_meta.py、attention.py 中相关调用
- 更新对应单元测试(test_cache_controller、test_swap_cache_ops、test_transfer_manager)
- 调整 config.py 中相关配置项
## Motivation

在 enable_cache_manager_v1 路径下,MTP(speculative decode)的 KV Cache 需要由
CacheController 统一管理,以复用 swap/transfer 能力,同时修复多模态场景下 block
hash 未携带 multimodal extra_keys 的问题。

## Modifications

- `cache_controller.py`
  - 新增 `initialize_mtp_kv_cache`:通过 CacheController 初始化 MTP KV Cache,
    并将其注册到 cache_kvs_map,使 transfer_manager 自动覆盖 MTP 层
  - `initialize_host_cache` 中的 num_layers 改为包含 MTP 额外 cache 层数,保证
    Host Cache 也为 MTP 分配足够空间
  - `_free_gpu_cache` 改名为 `free_gpu_cache`(对外可调用)

- `cache_utils.py`
  - 新增 `get_block_hash_extra_keys`:提取单个 block 内的多模态 hash 信息,
    对齐 PrefixCacheManager 的 multimodal extra_keys 逻辑
  - `get_request_block_hasher` 中在 hash_block_tokens 时携带 extra_keys,
    修复多模态场景 prefix cache 命中率不准的问题

- `spec_decode/mtp.py`
  - `update_mtp_block_num` 新增 `skip_cache_init` 参数,避免 v1 cache manager
    路径下重复初始化 MTP KV Cache

- `gpu_model_runner.py`
  - `initialize_kv_cache(v1)` 路径:在主模型 cache 初始化后,调用
    `cache_controller.initialize_mtp_kv_cache` 完成 MTP cache 创建
  - `clear_cache` / `wakeup` / `reset` 等路径:respect `enable_cache_manager_v1`
    标志,跳过重复的 proposer.initialize_kv_cache 调用

## Usage or Command

```bash
# 启动支持 MTP + cache_manager_v1 的推理服务(示例)
bash run.sh
```
…atch ops

1. Fix CuPy stream/event creation for multi-GPU: wrap all stream operations
   with cp.cuda.Device(device_id) context to ensure streams/events are bound
   to the correct device, preventing cross-device errors in multi-GPU setups.

2. Remove cudaSetDevice from SwapCacheAllLayers (handled by cupy context now).

3. Remove swap_cache_all_layers_batch op: simplified the implementation by
   removing the batch upload variant; all-layer transfers now use the standard
   swap_cache_all_layers with cupy device context.

4. Fix mm hash boundary comparison in get_block_hash_extra_keys: change
   strict less-than (<) to less-than-or-equal (<=) so that multimodal items
   ending exactly at block start are correctly excluded.

5. Extract config fields to KVCacheBase: model_config, cache_config,
   quant_config, parallel_config are now set in the base class __init__ to
   avoid duplication in CacheController and CacheManager subclasses.

6. Translate metadata.py docstrings from Chinese to English for broader
   contributor accessibility.

7. Add test_cache_utils.py: comprehensive unit tests for
   get_block_hash_extra_keys covering all boundary and overlap scenarios.

8. Expand test suite: test_request.py cache fields tests, test_radix_tree.py
   backup candidate tests, test_transfer_manager.py and test_cache_manager.py
   multi-GPU and concurrent operation tests.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@paddle-bot
Copy link
Copy Markdown

paddle-bot bot commented Mar 31, 2026

Thanks for your contribution!

… to CacheManager

## Motivation

修复两处问题:
1. `fastdeploy/engine/request.py` 中 `List` 未导入导致 pre-commit F821 报错
2. `write_policy` 归一化逻辑(`write_through` → `write_through_selective`)不应放在 `FDConfig`,移至 `CacheManager.__init__` 中,使其只影响 Cache Manager V1 的内部逻辑

## Modifications

- `fastdeploy/engine/request.py`: 在 `typing` 导入中补充 `List`,删除重复的 `CacheSwapMetadata` TYPE_CHECKING 导入,修复 F821/F811
- `fastdeploy/config.py`: 删除 `write_policy` 归一化逻辑
- `fastdeploy/cache_manager/v1/cache_manager.py`: 将归一化逻辑移入 `CacheManager.__init__`

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
## Motivation

修复 CI pre-commit 代码风格检查失败问题。

## Modifications

- `fastdeploy/engine/common_engine.py`: black 格式化
- `fastdeploy/worker/worker_process.py`: black 格式化 + isort 修复
- `fastdeploy/cache_manager/v1/storage/__init__.py`: isort 修复
- `fastdeploy/worker/gpu_worker.py`: isort 修复

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
## Motivation

更新 Cache Manager V1 相关模块,完善版权信息、改进模块结构与可维护性。

## Modifications

- `fastdeploy/cache_manager/v1/` 系列模块:补充版权 header,优化代码结构
- `fastdeploy/config.py`:配置项更新
- `fastdeploy/engine/sched/resource_manager_v1.py`:调度相关更新

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…sk parsing

## Motivation

将 worker_process 中重复的 task 解析逻辑收敛到 BatchRequest,减少代码冗余,提升可维护性。

## Modifications

- `fastdeploy/engine/request.py`:新增 `BatchRequest.from_tasks()` 类方法,统一将 task_queue 任务分类为推理请求和控制请求
- `fastdeploy/worker/worker_process.py`:使用 `BatchRequest.from_tasks()` 替代内联解析逻辑,并修复重复的 control_reqs 处理块

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…he tests

## Motivation

优化 Host cache 内存分配的 NUMA 亲和性,减少跨 NUMA 访问延迟;
同时跳过 swap cache ops 测试(当前环境不支持)。

## Modifications

- `fastdeploy/cache_manager/v1/cache_controller.py`:
  - 新增 `_get_numa_node_for_gpu()` 方法,通过 nvidia-smi 或 sysfs 获取 GPU 对应的 NUMA 节点
  - 新增 `_bind_to_closest_numa_node()` 方法,绑定当前线程到 GPU 最近的 NUMA 节点
  - 在 `initialize_host_cache()` 中调用 NUMA 绑定,优化 H2D 传输性能
- `tests/cache_manager/v1/test_swap_cache_ops.py`:跳过所有测试类(`TestSwapCacheAllLayersCorrectness`、`TestSwapCacheAllLayersPerformance`、`TestSwapCacheRandomBlockIndices`)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@codecov-commenter
Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 52.50965% with 1230 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@6727df8). Learn more about missing BASE report.

Files with missing lines Patch % Lines
fastdeploy/cache_manager/v1/cache_controller.py 33.50% 243 Missing and 9 partials ⚠️
fastdeploy/cache_manager/v1/cache_manager.py 48.76% 182 Missing and 26 partials ⚠️
fastdeploy/cache_manager/v1/transfer_manager.py 54.13% 112 Missing and 10 partials ⚠️
...tdeploy/cache_manager/v1/transfer/ipc/connector.py 0.00% 96 Missing ⚠️
fastdeploy/cache_manager/v1/cache_utils.py 63.72% 63 Missing and 15 partials ⚠️
...loy/cache_manager/v1/storage/mooncake/connector.py 0.00% 62 Missing ⚠️
...deploy/cache_manager/v1/transfer/rdma/connector.py 0.00% 61 Missing ⚠️
...oy/cache_manager/v1/storage/attnstore/connector.py 0.00% 58 Missing ⚠️
fastdeploy/cache_manager/v1/storage/__init__.py 15.15% 53 Missing and 3 partials ⚠️
fastdeploy/cache_manager/v1/transfer/__init__.py 17.54% 43 Missing and 4 partials ⚠️
... and 21 more
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7097   +/-   ##
==========================================
  Coverage           ?   72.28%           
==========================================
  Files              ?      423           
  Lines              ?    59040           
  Branches           ?     9323           
==========================================
  Hits               ?    42678           
  Misses             ?    13349           
  Partials           ?     3013           
Flag Coverage Δ
GPU 72.28% <52.50%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants