[XPU] Fix mtp cudagraph#7864
Conversation
|
Thanks for your contribution! |
PaddlePaddle-bot
left a comment
There was a problem hiding this comment.
🤖 Paddle-CI-Agent | pr_review |
2026-05-20 16:29:30
📋 Review 摘要
PR 概述:修复 XPU 平台 MTP(Multi-Token Prediction)投机解码在 CUDAGraph 模式下的问题,并新增 build_sampling_params XPU 自定义算子。
变更范围:custom_ops/xpu_ops/(新增算子)、fastdeploy/spec_decode/mtp_xpu.py、fastdeploy/worker/xpu_model_runner.py、fastdeploy/model_executor/layers/sample/sampler.py
影响面 Tag:[XPU] [Speculative Decoding] [OP]
问题
| 级别 | 文件 | 概述 |
|---|---|---|
| 🔴 Bug | fastdeploy/model_executor/layers/sample/sampler.py:1195 |
sampling_metadata.topp_seed 属性不存在,普通 XPU 采样路径运行时必抛 AttributeError |
| 📝 PR 规范 | — | PR 描述模板 Motivation / Modifications / Usage / Accuracy Tests 均为空 |
📝 PR 规范检查
PR 标题 [XPU] Fix mtp cudagraph 使用了合法 Tag,格式合规。
PR 描述中 ## Motivation、## Modifications、## Usage or Command、## Accuracy Tests 四个必填段落均为空(仅保留了模板注释占位符),Checklist 全部未勾选,不符合描述模板要求。
标题建议(可直接复制):
[XPU][BugFix] Fix MTP CUDAGraph and add build_sampling_params XPU op
PR 描述建议(可直接复制):
## Motivation
修复 XPU 平台 MTP(Multi-Token Prediction)投机解码在 CUDAGraph 模式下运行时出现的问题,主要包括:
1. `cu_seqlens_q_output` / `batch_id_per_token_output` 在 cudagraph 模式下应使用 `copy_` 原地更新而非重新赋值,否则破坏 graph 捕获的 tensor 地址稳定性。
2. 原 `padding_sampling_params`(Python 侧 CPU 实现)被替换为新增的 `build_sampling_params` XPU 自定义算子,同时在算子内部完成 `infer_seed` 的原地更新,以避免在 cudagraph 外额外操作。
3. MTP draft model 当前阶段禁用 cudagraph(TODO),并修复相关 capture 流程。
## Modifications
- `custom_ops/xpu_ops/src/ops/mtp/build_sampling_params.cc`:新增 Paddle 自定义算子入口,注册 `build_sampling_params` op。
- `custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h`:声明 `build_sampling_params` C 接口。
- `custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/build_sampling_params.xpu`:新增 XPU3 kernel 实现。
- `custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/build_sampling_params.cpp`:新增 CPU wrapper 和 XPU3 wrapper。
- `custom_ops/xpu_ops/test/test_build_sampling_params.py`:新增单元测试,覆盖纯 decoder、纯 encoder、混合、单条、seed wrap-around 等场景。
- `fastdeploy/model_executor/layers/sample/sampler.py`:`_verify_and_sample_xpu` 改用 `build_sampling_params` XPU 算子;`forward_xpu` 新增 `increment_value` 参数。
- `fastdeploy/model_executor/xpu_pre_and_post_process.py`:cudagraph 模式下改用 `copy_` 原地更新 `cu_seqlens_q_output` 和 `batch_id_per_token_output`。
- `fastdeploy/spec_decode/mtp_xpu.py`:draft model 当前强制禁用 cudagraph(TODO),`_propose` 新增 cudagraph padding 逻辑。
- `fastdeploy/worker/xpu_model_runner.py`:`increment_value` 改为与投机解码 token 数联动;投机解码模式下 `infer_seed` 更新移入 `build_sampling_params` 算子内部。
## Usage or Command
N/A
## Accuracy Tests
N/A
## Checklist
- [x] Add at least a tag in the PR title.
- Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
- You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [x] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.总体评价
本 PR 新增 build_sampling_params XPU 算子并修复 MTP cudagraph 相关问题,整体方向正确,但存在一个 P0 Bug:普通 XPU 采样路径(_normal_sample_xpu)引用了不存在的 sampling_metadata.topp_seed 属性,会在非投机解码的正常推理时直接崩溃,需在合入前修复。
| top_k=sampling_metadata.top_k, | ||
| top_k_list=sampling_metadata.top_k_list, | ||
| topp_seed=topp_seed, | ||
| topp_seed=sampling_metadata.topp_seed, |
There was a problem hiding this comment.
🔴 Bug sampling_metadata.topp_seed 属性不存在,运行时会抛出 AttributeError。
SamplingMetadata 数据类中只有 seed 字段,没有 topp_seed 字段。此处直接传 sampling_metadata.topp_seed 会在普通 XPU 采样路径(非投机解码)下必然崩溃。
原代码通过 padding_sampling_params(...) 将 seed 展开为 topp_seed(shape [token_num, 1])后再传入。本次 PR 对投机解码路径(_verify_and_sample_xpu)已正确改为调用 build_sampling_params 生成 topp_seed,但普通采样路径遗漏了同等处理。
建议修复方案(二选一):
- 在调用
top_k_top_p_sampling之前,先调用build_sampling_params生成 padded seed,与投机解码路径保持一致; - 或者在
SamplingMetadata中新增topp_seed字段并在_build_sampling_metadata中填充。
示例(方案1):
_, topp_seed = build_sampling_params(
sampling_metadata.top_p,
sampling_metadata.top_k,
sampling_metadata.seed,
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
token_num_output_cpu=int(share_inputs["cu_seqlens_q_output"][-1]),
increment_value=increment_value,
)[1:] # 取 top_k_padding, topp_seed
_, next_tokens = top_k_top_p_sampling(
probs,
top_p=sampling_metadata.top_p,
top_k=sampling_metadata.top_k,
top_k_list=sampling_metadata.top_k_list,
topp_seed=topp_seed,
)
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## develop #7864 +/- ##
==========================================
Coverage ? 63.32%
==========================================
Files ? 462
Lines ? 64412
Branches ? 9878
==========================================
Hits ? 40786
Misses ? 20858
Partials ? 2768
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
CI报告基于以下代码生成(30分钟更新一次): 1 任务总览Required 任务当前 8/10 通过,仍有 2 个 required 失败任务 阻塞合并:主单测覆盖率失败和 Approval 未完成。请优先处理 required 失败项;optional 失败仅供参考。
2 任务状态汇总日志列说明:失败任务直接使用日志链接;运行中任务手动拼接 Job 链接。 2.1 Required任务 : 8/10 通过
2.2 可选任务 — 25/28 通过
3 失败详情(仅 required)Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage — 覆盖率阈值失败(置信度: 高)Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage
失败用例: 无,单测通过,失败发生在覆盖率校验步骤。 根因详情: 关键日志: 修复建议:
修复建议摘要: 补测 mtp_xpu/sampler 新增分支 关联变更: 链接: 查看日志 Approval — 等待人工审批(置信度: 高)该 Job 需要人工 Approval,完成审批后 CI 才会继续执行。 |
Motivation
修复 XPU 平台 MTP(Multi-Token Prediction)投机解码在 CUDAGraph 模式下运行时出现的问题,主要包括:
cu_seqlens_q_output/batch_id_per_token_output在 cudagraph 模式下应使用copy_原地更新而非重新赋值,否则破坏 graph 捕获的 tensor 地址稳定性。padding_sampling_params(Python 侧 CPU 实现)被替换为新增的build_sampling_paramsXPU 自定义算子,同时在算子内部完成infer_seed的原地更新,以避免在 cudagraph 外额外操作。Modifications
custom_ops/xpu_ops/src/ops/mtp/build_sampling_params.cc:新增 Paddle 自定义算子入口,注册build_sampling_paramsop。custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h:声明build_sampling_paramsC 接口。custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/build_sampling_params.xpu:新增 XPU3 kernel 实现。custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/build_sampling_params.cpp:新增 CPU wrapper 和 XPU3 wrapper。custom_ops/xpu_ops/test/test_build_sampling_params.py:新增单元测试,覆盖纯 decoder、纯 encoder、混合、单条、seed wrap-around 等场景。fastdeploy/model_executor/layers/sample/sampler.py:_verify_and_sample_xpu改用build_sampling_paramsXPU 算子;forward_xpu新增increment_value参数。fastdeploy/model_executor/xpu_pre_and_post_process.py:cudagraph 模式下改用copy_原地更新cu_seqlens_q_output和batch_id_per_token_output。fastdeploy/spec_decode/mtp_xpu.py:draft model 当前强制禁用 cudagraph(TODO),_propose新增 cudagraph padding 逻辑。fastdeploy/worker/xpu_model_runner.py:increment_value改为与投机解码 token 数联动;投机解码模式下infer_seed更新移入build_sampling_params算子内部。Usage or Command
N/A
Accuracy Tests
Checklist
[FDConfig],[APIServer],[Engine],[Scheduler],[PD Disaggregation],[Executor],[Graph Optimization],[Speculative Decoding],[RL],[Models],[Quantization],[Loader],[OP],[KVCache],[DataProcessor],[BugFix],[Docs],[CI],[Optimization],[Feature],[Benchmark],[Others],[XPU],[HPU],[GCU],[DCU],[Iluvatar],[Metax]]pre-commitbefore commit.releasebranch, make sure the PR has been submitted to thedevelopbranch, then cherry-pick it to thereleasebranch with the[Cherry-Pick]PR tag.