[Cherry-Pick][Feature] Support glm yarn scaling rope (#7893) #7894

Open
Sunny-bot1 wants to merge 1 commit into PaddlePaddle:release/2.6 from Sunny-bot1:glm_yarn_26

Conversation

Sunny-bot1 (Collaborator) commented May 22, 2026

Motivation

Support 256K long-context inference for GLM.

Modifications

  • Reuse GptOssScalingRotaryEmbedding for GLM
  • Rename it to the more general YarnScalingRotaryEmbedding
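
For context, YaRN blends the original RoPE frequencies with position-interpolated ones on a per-dimension ramp and applies a small attention scale. The following is a minimal pure-Python sketch of that published recipe, not FastDeploy's implementation of YarnScalingRotaryEmbedding. The function names, the beta_fast=32 and beta_slow=1 defaults, and the example values (head dim 128, base 10000, factor 8, original length 32768) are assumptions chosen for illustration.

```python
import math


def yarn_correction_dim(num_rotations, dim, base, orig_max_pos):
    """Dimension index whose wavelength completes `num_rotations` turns over orig_max_pos."""
    return dim * math.log(orig_max_pos / (num_rotations * 2 * math.pi)) / (2 * math.log(base))


def yarn_inv_freq(dim, base, factor, orig_max_pos, beta_fast=32.0, beta_slow=1.0):
    """Return (list of inverse frequencies, attention scale) for YaRN-scaled RoPE."""
    low = max(math.floor(yarn_correction_dim(beta_fast, dim, base, orig_max_pos)), 0)
    high = min(math.ceil(yarn_correction_dim(beta_slow, dim, base, orig_max_pos)), dim - 1)
    if high == low:
        high += 0.001  # avoid division by zero in the ramp

    inv_freq = []
    for i in range(dim // 2):
        extrap = 1.0 / (base ** (2 * i / dim))  # original (extrapolated) frequency
        interp = extrap / factor                # position-interpolated frequency
        # ramp is 0 for high-frequency dims (kept as-is) and 1 for low-frequency dims
        ramp = min(max((i - low) / (high - low), 0.0), 1.0)
        inv_freq.append(extrap * (1.0 - ramp) + interp * ramp)

    attn_scale = 0.1 * math.log(factor) + 1.0 if factor > 1 else 1.0
    return inv_freq, attn_scale


# Illustrative values only; real values come from the model's rope_scaling config.
inv_freq, attn_scale = yarn_inv_freq(dim=128, base=10000.0, factor=8.0, orig_max_pos=32768)
print(len(inv_freq), inv_freq[0], inv_freq[-1], attn_scale)
```

The highest-frequency dimensions keep their original rotation speed, while the lowest-frequency ones are slowed by the full scaling factor, which is what lets positions beyond the original training length stay in a familiar range.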

Usage or Command

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

python -m fastdeploy.entrypoints.openai.api_server --model ${model_path} \
    --max-model-len 262144 \
    --port 8908 \
    --tensor-parallel-size 8 \
    --max-num-seqs 64 

Accuracy Tests

"""
Needle-in-a-Haystack test for GLM YaRN 256K context.
Usage:
    python test_yarn_256k.py --host 0.0.0.0 --port 8908
    python test_yarn_256k.py --host 0.0.0.0 --port 8908 --target-tokens 131072
    python test_yarn_256k.py --host 0.0.0.0 --port 8908 --depths 0.1 0.5 0.9
"""

import argparse
import time

import requests

# ---------------------------------------------------------------------------
# Filler text: one block of roughly 260 Chinese chars; each char ≈ 1 token
# ---------------------------------------------------------------------------
_FILLER_BLOCK = (
    "春天,阳光明媚,微风轻拂,万物从冬日的沉睡中苏醒。"
    "田野里,麦苗返青,嫩绿的叶片在阳光下闪闪发光。"
    "河边的柳树抽出了新芽,细长的枝条随风摇曳,倒映在清澈的水面上。"
    "小燕子从南方飞回来了,在屋檐下忙碌地衔泥筑巢。"
    "孩子们脱掉了厚重的冬装,在草地上欢快地奔跑嬉戏。"
    "老人们坐在公园的长椅上,沐浴着温暖的阳光,脸上绽放出满足的笑容。"
    "花圃里,迎春花、桃花、杏花竞相开放,争奇斗艳,空气中弥漫着淡淡的花香。"
    "这是一年中最美好的季节,充满了生机与希望。"
)  # ~260 chars ≈ 260 tokens

NEEDLE_TEMPLATE = "\n\n【关键信息】本次测试的秘密验证码是:{code},请牢记这个验证码。\n\n"
QUESTION = "请问文中提到的秘密验证码是什么?只需回答验证码本身,不要其他解释。"
NEEDLE_CODE = "YARN_LONG_CTX_TEST_7391"


def build_prompt(target_tokens: int, depth: float) -> tuple[str, int]:
    """Build a long prompt with needle inserted at `depth` position."""
    repeat = max(1, target_tokens // len(_FILLER_BLOCK) + 1)
    filler = _FILLER_BLOCK * repeat
    filler = filler[:target_tokens]  # trim to target length (chars ≈ tokens)

    insert_pos = int(len(filler) * depth)
    needle = NEEDLE_TEMPLATE.format(code=NEEDLE_CODE)

    context = filler[:insert_pos] + needle + filler[insert_pos:]
    prompt = context + "\n\n" + QUESTION

    approx_tokens = len(context)
    return prompt, approx_tokens


def chat(host: str, port: int, prompt: str, model: str, max_tokens: int = 1000) -> str:
    url = f"http://{host}:{port}/v1/chat/completions"
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.0,
        "stream": False,
    }
    resp = requests.post(url, json=payload, timeout=300)
    resp.raise_for_status()
    data = resp.json()
    return data["choices"][0]["message"]["content"].strip()


def run_test(host: str, port: int, model: str, target_tokens: int, depths: list[float]):
    print(f"\n{'='*60}")
    print(f"  Needle-in-a-Haystack  |  target ~{target_tokens//1000}K tokens")
    print(f"  Needle code: {NEEDLE_CODE}")
    print(f"{'='*60}")
    print(f"{'Depth':>8}  {'~Tokens':>10}  {'Pass':>6}  {'Latency':>10}  Response")
    print(f"{'-'*70}")

    passed = 0
    for depth in depths:
        prompt, approx_tokens = build_prompt(target_tokens, depth)
        t0 = time.time()
        try:
            answer = chat(host, port, prompt, model)
            elapsed = time.time() - t0
            ok = NEEDLE_CODE in answer
            passed += ok
            status = "✓" if ok else "✗"
            # flatten newlines so each answer prints on a single row
            display = answer.replace("\n", " ")
            print(f"{depth:>8.1f}  {approx_tokens:>10,}  {status:>6}  {elapsed:>8.1f}s  {display}")
        except Exception as e:
            elapsed = time.time() - t0
            print(f"{depth:>8.1f}  {approx_tokens:>10,}  {'ERR':>6}  {elapsed:>8.1f}s  {e}")

    print(f"{'-'*70}")
    print(f"Result: {passed}/{len(depths)} passed\n")


def main():
    parser = argparse.ArgumentParser(description="YaRN 256K needle-in-a-haystack test")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8908)
    parser.add_argument("--model", default="EB", help="model name in API")
    parser.add_argument(
        "--target-tokens", type=int, default=250000,
        help="approx total context tokens (default: 250000)"
    )
    parser.add_argument(
        "--depths", type=float, nargs="+", default=[0.1, 0.3, 0.5, 0.7, 0.9],
        help="needle insertion depths as fraction of context (default: 0.1 0.3 0.5 0.7 0.9)"
    )
    args = parser.parse_args()

    run_test(args.host, args.port, args.model, args.target_tokens, args.depths)


if __name__ == "__main__":
    main()

Results

============================================================
  Needle-in-a-Haystack  |  target ~250K tokens
  Needle code: YARN_LONG_CTX_TEST_7391
============================================================
   Depth     ~Tokens    Pass     Latency  Response
----------------------------------------------------------------------
     0.1     250,055       ✓      21.3s  用户在消息中提到了一个秘密验证码:“YARN_LONG_CTX_TEST_7391”。用户要求我记住这个验证码,并询问文中提到的秘密验证码是什么。用户要求只回答验证码本身,不要其他解释。  根据用户的消息,验证码是“YARN_LONG_CTX_TEST_7391”。用户要求只回答验证码本身,不要其他解释。所以我的回答应该只是这个验证码。</think>YARN_LONG_CTX_TEST_7391
     0.3     250,055       ✓      20.9s  用户提供的文本是一段关于春天的描述,重复了很多次。在文本的末尾,用户提到“【关键信息】本次测试的秘密验证码是:YARN_LONG_CTX_TEST_7391,请牢记这个验证码。”用户问:“请问文中提到的秘密验证码是什么?只需回答验证码本身,不要其他解释。”  所以,我需要从文本中提取出验证码。验证码是“YARN_LONG_CTX_TEST_7391”。用户要求只回答验证码本身,不要其他解释。因此,我应该直接回复“YARN_LONG_CTX_TEST_7391”。</think>YARN_LONG_CTX_TEST_7391
     0.5     250,055       ✓      18.7s  用户提供的文本是一段关于春天的描写,重复了很多次。在文本的末尾,有一段关键信息:“【关键信息】本次测试的秘密验证码是:YARN_LONG_CTX_TEST_7391,请牢记这个验证码。” 用户问:“请问文中提到的秘密验证码是什么?只需回答验证码本身,不要其他解释。” 所以答案就是:YARN_LONG_CTX_TEST_7391 直接给出验证码即可。</think>YARN_LONG_CTX_TEST_7391
     0.7     250,055       ✓      14.9s  用户要求我提取并返回秘密验证码。根据提示,验证码是:YARN_LONG_CTX_TEST_7391。用户明确要求只返回验证码本身,不要其他解释。</think>YARN_LONG_CTX_TEST_7391
     0.9     250,055       ✓      10.3s  用户在消息中嵌入了一个秘密验证码,并要求我回答验证码本身。验证码是:YARN_LONG_CTX_TEST_7391  用户明确要求只回答验证码本身,不要其他解释。</think>YARN_LONG_CTX_TEST_7391
----------------------------------------------------------------------
Result: 5/5 passed
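
Each response above contains the model's reasoning, then `</think>`, then the final answer. The needle code also appears inside the reasoning, so a substring check over the whole response passes regardless of what the final answer says. A stricter check could look only at the text after the last `</think>` marker. The sketch below assumes that marker format, as seen in the output above; `final_answer` and `needle_found` are hypothetical helper names, not part of the test script.

```python
THINK_END = "</think>"  # reasoning terminator observed in the responses above


def final_answer(response: str) -> str:
    """Return the text after the last reasoning terminator, or the whole response."""
    _, sep, tail = response.rpartition(THINK_END)
    return (tail if sep else response).strip()


def needle_found(response: str, needle: str) -> bool:
    """Check the needle against the final answer only, ignoring the reasoning trace."""
    return needle in final_answer(response)


sample = "用户要求只回答验证码本身。</think>YARN_LONG_CTX_TEST_7391"
print(final_answer(sample))
print(needle_found(sample, "YARN_LONG_CTX_TEST_7391"))
```

In the test script this would replace the `NEEDLE_CODE in answer` expression with `needle_found(answer, NEEDLE_CODE)`.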

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

paddle-bot Bot commented May 22, 2026

Thanks for your contribution!

PaddlePaddle-bot left a comment

🤖 Paddle-CI-Agent | pr_review | 2026-05-22 12:24:29

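📋 Review Summary

PR overview: Adds YaRN RoPE scaling support for the GLM architecture and renames GptOssScalingRotaryEmbedding to YarnScalingRotaryEmbedding to unify naming.
Scope of changes: fastdeploy/model_executor/layers/rotary_embedding.py
Impact tags: [OP] [Models]

Issues

| Level | File | Summary |
| --- | --- | --- |
| ❓ Question | rotary_embedding.py:354 | No existence check for the original_max_position_embeddings key; may raise KeyError |
| 📝 PR conventions | - | Typo in the title (Supoort); all description sections are empty |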

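📝 PR Conventions Check

Issue 1: The title contains a typo: Supoort should be Support. Otherwise the Cherry-Pick format is compliant (the tags are correct and the original PR number is at the end).

Suggested title (can be copied directly):

  • [Cherry-Pick][Feature] Support glm yarn scaling rope (#7893)

Issue 2: All description sections are empty or contain only placeholders, which does not meet the template requirements.

Suggested PR description (can be copied directly):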

## Motivation
Add YaRN (Yet another RoPE extensioN) position-encoding scaling support for GLM-series models (e.g. GLM4), to accommodate model configs deployed for longer contexts (`rope_scaling.type == "yarn"`).

## Modifications
- Rename the `GptOssScalingRotaryEmbedding` class to `YarnScalingRotaryEmbedding` to unify naming
- In the `Glm` branch of `get_rope_impl`, add rope_scaling detection: when the `rope_scaling` config type is `yarn` and it contains a `factor` field, use `YarnScalingRotaryEmbedding` instead of the original `GlmRotaryEmbedding`

## Usage or Command
N/A

## Accuracy Tests
N/A

## Checklist

- [x] Add at least a tag in the PR title.
  - Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
  - You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [ ] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [x] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.

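Overall assessment

The implementation approach is clear, the conditional branch for GLM YaRN RoPE is reasonable, and the class rename leaves no risk of stale references. Consider adding an existence check for the original_max_position_embeddings key to make config handling more robust; please also fix the PR title typo and complete the description template.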

rotary_emb_layer = YarnScalingRotaryEmbedding(
    rotary_dim=yarn_rotary_dim,
    base=base,
    original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],

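❓ Question: no existence check for the original_max_position_embeddings key

The code already performs a partial check via "factor" in rope_scaling, but rope_scaling["original_max_position_embeddings"] is a direct subscript access. If a GLM model config lacks this key (for example, an older config.json), a KeyError is raised at runtime.

Suggested fix:

original_max_position_embeddings=rope_scaling.get("original_max_position_embeddings", 8192),

Or also check that the key exists in the preceding condition: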

and "original_max_position_embeddings" in rope_scaling
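
The two options in the review comment behave differently: a `.get()` default silently picks a context length, while the extra precondition falls back to the non-YaRN path. A third variant is to fail with a descriptive error. The sketch below illustrates these choices on a plain dict. `resolve_yarn_params` is a hypothetical helper, not FastDeploy code, and the 8192 value used in the usage example comes from the reviewer's suggestion rather than from any GLM config.

```python
from typing import Optional, Tuple


def resolve_yarn_params(
    rope_scaling: Optional[dict],
    default_original_len: Optional[int] = None,
) -> Optional[Tuple[float, int]]:
    """Return (factor, original_max_position_embeddings), or None when YaRN does not apply."""
    if not rope_scaling:
        return None
    # The CI report mentions both key names for the scaling type.
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
    if rope_type != "yarn" or "factor" not in rope_scaling:
        return None
    original_len = rope_scaling.get("original_max_position_embeddings", default_original_len)
    if original_len is None:
        # No default supplied: fail loudly instead of raising a bare KeyError.
        raise ValueError("rope_scaling type is 'yarn' but 'original_max_position_embeddings' is missing")
    return float(rope_scaling["factor"]), int(original_len)


print(resolve_yarn_params({"type": "yarn", "factor": 4}, default_original_len=8192))
print(resolve_yarn_params({"type": "linear", "factor": 4}))
```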

PaddlePaddle-bot commented May 22, 2026

🤖 Paddle-CI-Agent | ci_status_monitor | 2026-05-22 14:47:42

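This CI report was generated from the following code (updated every 30 minutes):


1 Task overview

Required tasks: 1 failed, 0 running, 0 pending. The merge is blocked by the coverage gate; consider adding unit tests for the new branches and rerunning.

| Total runs (reruns) | Total tasks | ✅ Passed | ❌ Failed | ⏳ Running | ⏸️ Pending | Skipped |
| --- | --- | --- | --- | --- | --- | --- |
| 36 (0) | 36 | 31 | 4 | 0 | 1 | 0 |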

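2 Task status summary

Log column: failed tasks link directly to the log; running or pending tasks show the Job or Workflow link.

2.1 Required tasks: 9/10 passed

Required tasks block merging; handle their failures first.

| Status | Task | Duration | Root cause | Suggested fix | Log | Rerun |
| --- | --- | --- | --- | --- | --- | --- |
| ❌ | Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage | 1h13m | PR issue: insufficient coverage of the new GLM YaRN branch | Add unit tests for the get_rope_impl branches | Job | - |
| ✅ | The other 9 required tasks passed | - | - | - | - | - |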

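2.2 Optional tasks: 22/26 passed

Optional tasks do not block merging; their failures are for reference only.

| Status | Task | Duration | Log | Rerun |
| --- | --- | --- | --- | --- |
| ❌ | Run iluvatar Tests / run_iluvatar_cases | 18m40s | Job | - |
| ❌ | Check PR Template | 9s | Job | - |
| ❌ | Trigger Jenkins for PR | 1m19s | Job | - |
| ⏸️ | CI_HPU | - | Workflow | - |
| ✅ | The other 22 optional tasks passed | - | - | - |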

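3 Failure details (required only)

Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage: coverage gate (confidence: high)

  • Status: ❌ Failed
  • Error type: coverage gate
  • Confidence: high
  • Root cause summary: insufficient coverage of the new GLM YaRN branch
  • Analyzer: ci_analyze_unittest_fastdeploy

Failed cases:

| Test | Error | Root cause |
| --- | --- | --- |
| Coverage check | Coverage exit code 9 | Unit tests passed, but diff coverage is only 14%, below the 80% threshold |

Root cause details:
This PR modifies only fastdeploy/model_executor/layers/rotary_embedding.py. It adds a YaRN branch for the GLM architecture when rope_scaling.type/rope_type == "yarn", and reuses the GptOss scaling rope as YarnScalingRotaryEmbedding. The CI log shows that the unit tests passed (TEST_EXIT_CODE=0); the failure occurred in Verify Code Coverage Threshold (80%). In the coverage report, the diff coverage for this file is 14.2857%, and the uncovered lines include 343, 344, 350, 351, 362, and 365, so the failure is due to insufficient coverage of the code added or modified by this PR.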

Key logs:

All tests passed
Coverage generation failed (exit code 9)
GPU Patch Coverage Details:
"fastdeploy/model_executor/layers/rotary_embedding.py": {
  "percent_covered": 14.285714285714292,
  "violation_lines": [343, 344, 350, 351, 362, 365]
}
"total_percent_covered": 14
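
The reported percentage follows from the line counts: six uncovered lines at about 14.29% patch coverage implies seven changed executable lines, one of which was hit. A small sketch of that arithmetic, where the total of 7 is inferred from the log rather than stated in it:

```python
import math

violation_lines = [343, 344, 350, 351, 362, 365]  # uncovered lines from the log above
changed_lines = 7   # inferred: 6 missing lines at ~14.29% coverage
threshold = 80.0    # gate from "Verify Code Coverage Threshold (80%)"

covered = changed_lines - len(violation_lines)
percent = 100.0 * covered / changed_lines
print(f"patch coverage: {percent:.2f}% ({covered}/{changed_lines} lines)")
print("gate passed" if percent >= threshold else "gate failed")

# Minimum number of covered lines needed to clear the gate for the same diff.
needed = math.ceil(threshold * changed_lines / 100)
print(f"need at least {needed} of {changed_lines} lines covered")
```

Under that inference, covering the YaRN branch alone (lines 343 to 351) would bring the diff to 5 of 7 lines, which is still below the gate, so at least one of the other two uncovered lines would need a test as well.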

Suggested fixes:

  1. Add a minimal unit test for get_rope_impl to the FastDeploy unit tests: construct a model_config with architectures=["Glm..."] and rope_scaling={"type"/"rope_type": "yarn", "factor": ..., "original_max_position_embeddings": ...} to cover the YaRN branch at fastdeploy/model_executor/layers/rotary_embedding.py L343-L351.
  2. In the same test file, add a GLM fallback case (non-YaRN or no rope_scaling) to cover the GlmRotaryEmbedding branch at L362; to cover the whole diff, also add a GptOss case covering L365.
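
The three cases in the suggestions above can be sketched as branch-selection tests. The real get_rope_impl signature is not shown on this page, so the sketch uses a hypothetical stand-in, `select_rope_class`, that mirrors only the branching described in the report. The prefix matching and the placeholder architecture names are assumptions; real tests would call get_rope_impl with a model_config and assert on the type of the returned layer.

```python
def select_rope_class(architecture: str, rope_scaling=None) -> str:
    """Hypothetical stand-in for the branch selection described in the CI report."""
    if architecture.startswith("Glm"):
        scaling = rope_scaling or {}
        rope_type = scaling.get("rope_type", scaling.get("type"))
        if rope_type == "yarn" and "factor" in scaling:
            return "YarnScalingRotaryEmbedding"  # new YaRN branch
        return "GlmRotaryEmbedding"              # existing GLM fallback
    if architecture.startswith("GptOss"):
        return "YarnScalingRotaryEmbedding"      # renamed GptOss class
    return "default"


# Placeholder architecture names; substitute the ones the model config really uses.
GLM_ARCH = "GlmPlaceholderForCausalLM"
GPT_OSS_ARCH = "GptOssPlaceholderForCausalLM"


def test_glm_yarn_branch():
    scaling = {"rope_type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768}
    assert select_rope_class(GLM_ARCH, scaling) == "YarnScalingRotaryEmbedding"


def test_glm_fallback_branch():
    assert select_rope_class(GLM_ARCH, None) == "GlmRotaryEmbedding"
    assert select_rope_class(GLM_ARCH, {"type": "linear", "factor": 2.0}) == "GlmRotaryEmbedding"


def test_gpt_oss_branch():
    assert select_rope_class(GPT_OSS_ARCH) == "YarnScalingRotaryEmbedding"


for test in (test_glm_yarn_branch, test_glm_fallback_branch, test_gpt_oss_branch):
    test()
print("all branch-selection checks passed")
```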

Fix summary: add unit tests for the get_rope_impl branches

Related changes: fastdeploy/model_executor/layers/rotary_embedding.py L343-L365
Link: View logs

Sunny-bot1 changed the title from "[Cherry-Pick][Feature] Supoort glm yarn scaling rope (#7893)" to "[Cherry-Pick][Feature] Support glm yarn scaling rope (#7893)" on May 22, 2026
codecov-commenter commented

Codecov Report

❌ Patch coverage is 14.28571% with 6 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (release/2.6@b562b8d). Learn more about missing BASE report.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...stdeploy/model_executor/layers/rotary_embedding.py | 14.28% | 6 Missing ⚠️ |
Additional details and impacted files
@@              Coverage Diff               @@
##             release/2.6    #7894   +/-   ##
==============================================
  Coverage               ?   72.40%           
==============================================
  Files                  ?      381           
  Lines                  ?    54229           
  Branches               ?     8474           
==============================================
  Hits                   ?    39266           
  Misses                 ?    12203           
  Partials               ?     2760           
| Flag | Coverage Δ |
| --- | --- |
| GPU | 72.40% <14.28%> (?) |
