Skip to content

[Feature] Support glm yarn scaling rope#7893

Open
Sunny-bot1 wants to merge 1 commit into
PaddlePaddle:developfrom
Sunny-bot1:glm_yarn
Open

[Feature] Support glm yarn scaling rope#7893
Sunny-bot1 wants to merge 1 commit into
PaddlePaddle:developfrom
Sunny-bot1:glm_yarn

Conversation

@Sunny-bot1
Copy link
Copy Markdown
Collaborator

@Sunny-bot1 Sunny-bot1 commented May 22, 2026

Motivation

支持GLM 256K 长上下文推理

Modifications

  • 复用 GptOssScalingRotaryEmbedding
  • 使用更通用的名字 YarnScalingRotaryEmbedding

Usage or Command

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

python -m fastdeploy.entrypoints.openai.api_server --model ${model_path} \
    --max-model-len 262144 \
    --port 8908 \
    --tensor-parallel-size 8 \
    --max-num-seqs 64 

Accuracy Tests

"""
Needle-in-a-Haystack test for GLM YaRN 256K context.
Usage:
    python test_yarn_256k.py --host 0.0.0.0 --port 8908
    python test_yarn_256k.py --host 0.0.0.0 --port 8908 --target-tokens 131072
    python test_yarn_256k.py --host 0.0.0.0 --port 8908 --depths 0.1 0.5 0.9
"""

import argparse
import json
import time

import requests

# ---------------------------------------------------------------------------
# Filler text: ~500 chars/block, each Chinese char ≈ 1 token
# ---------------------------------------------------------------------------
_FILLER_BLOCK = (
    "春天,阳光明媚,微风轻拂,万物从冬日的沉睡中苏醒。"
    "田野里,麦苗返青,嫩绿的叶片在阳光下闪闪发光。"
    "河边的柳树抽出了新芽,细长的枝条随风摇曳,倒映在清澈的水面上。"
    "小燕子从南方飞回来了,在屋檐下忙碌地衔泥筑巢。"
    "孩子们脱掉了厚重的冬装,在草地上欢快地奔跑嬉戏。"
    "老人们坐在公园的长椅上,沐浴着温暖的阳光,脸上绽放出满足的笑容。"
    "花圃里,迎春花、桃花、杏花竞相开放,争奇斗艳,空气中弥漫着淡淡的花香。"
    "这是一年中最美好的季节,充满了生机与希望。"
)  # ~260 chars ≈ 260 tokens

NEEDLE_TEMPLATE = "\n\n【关键信息】本次测试的秘密验证码是:{code},请牢记这个验证码。\n\n"
QUESTION = "请问文中提到的秘密验证码是什么?只需回答验证码本身,不要其他解释。"
NEEDLE_CODE = "YARN_LONG_CTX_TEST_7391"


def build_prompt(target_tokens: int, depth: float) -> tuple[str, int]:
    """Build a long prompt with needle inserted at `depth` position."""
    repeat = max(1, target_tokens // len(_FILLER_BLOCK) + 1)
    filler = _FILLER_BLOCK * repeat
    filler = filler[:target_tokens]  # trim to target length (chars ≈ tokens)

    insert_pos = int(len(filler) * depth)
    needle = NEEDLE_TEMPLATE.format(code=NEEDLE_CODE)

    context = filler[:insert_pos] + needle + filler[insert_pos:]
    prompt = context + "\n\n" + QUESTION

    approx_tokens = len(context)
    return prompt, approx_tokens


def chat(host: str, port: int, prompt: str, model: str, max_tokens: int = 1000) -> str:
    url = f"http://{host}:{port}/v1/chat/completions"
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.0,
        "stream": False,
    }
    resp = requests.post(url, json=payload, timeout=300)
    resp.raise_for_status()
    data = resp.json()
    return data["choices"][0]["message"]["content"].strip()


def run_test(host: str, port: int, model: str, target_tokens: int, depths: list[float]):
    print(f"\n{'='*60}")
    print(f"  Needle-in-a-Haystack  |  target ~{target_tokens//1000}K tokens")
    print(f"  Needle code: {NEEDLE_CODE}")
    print(f"{'='*60}")
    print(f"{'Depth':>8}  {'~Tokens':>10}  {'Pass':>6}  {'Latency':>10}  Response")
    print(f"{'-'*70}")

    passed = 0
    for depth in depths:
        prompt, approx_tokens = build_prompt(target_tokens, depth)
        t0 = time.time()
        try:
            answer = chat(host, port, prompt, model)
            elapsed = time.time() - t0
            ok = NEEDLE_CODE in answer
            passed += ok
            status = "✓" if ok else "✗"
            # truncate long answers for display
            display = answer.replace("\n", " ")
            print(f"{depth:>8.1f}  {approx_tokens:>10,}  {status:>6}  {elapsed:>8.1f}s  {display}")
        except Exception as e:
            elapsed = time.time() - t0
            print(f"{depth:>8.1f}  {approx_tokens:>10,}  {'ERR':>6}  {elapsed:>8.1f}s  {e}")

    print(f"{'-'*70}")
    print(f"Result: {passed}/{len(depths)} passed\n")


def main():
    parser = argparse.ArgumentParser(description="YaRN 256K needle-in-a-haystack test")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8908)
    parser.add_argument("--model", default="EB", help="model name in API")
    parser.add_argument(
        "--target-tokens", type=int, default=250000,
        help="approx total context tokens (default: 250000)"
    )
    parser.add_argument(
        "--depths", type=float, nargs="+", default=[0.1, 0.3, 0.5, 0.7, 0.9],
        help="needle insertion depths as fraction of context (default: 0.1 0.3 0.5 0.7 0.9)"
    )
    args = parser.parse_args()

    run_test(args.host, args.port, args.model, args.target_tokens, args.depths)


if __name__ == "__main__":
    main()

结果

============================================================
  Needle-in-a-Haystack  |  target ~250K tokens
  Needle code: YARN_LONG_CTX_TEST_7391
============================================================
   Depth     ~Tokens    Pass     Latency  Response
----------------------------------------------------------------------
     0.1     250,055       ✓      21.3s  用户在消息中提到了一个秘密验证码:“YARN_LONG_CTX_TEST_7391”。用户要求我记住这个验证码,并询问文中提到的秘密验证码是什么。用户要求只回答验证码本身,不要其他解释。  根据用户的消息,验证码是“YARN_LONG_CTX_TEST_7391”。用户要求只回答验证码本身,不要其他解释。所以我的回答应该只是这个验证码。</think>YARN_LONG_CTX_TEST_7391
     0.3     250,055       ✓      20.9s  用户提供的文本是一段关于春天的描述,重复了很多次。在文本的末尾,用户提到“【关键信息】本次测试的秘密验证码是:YARN_LONG_CTX_TEST_7391,请牢记这个验证码。”用户问:“请问文中提到的秘密验证码是什么?只需回答验证码本身,不要其他解释。”  所以,我需要从文本中提取出验证码。验证码是“YARN_LONG_CTX_TEST_7391”。用户要求只回答验证码本身,不要其他解释。因此,我应该直接回复“YARN_LONG_CTX_TEST_7391”。</think>YARN_LONG_CTX_TEST_7391
     0.5     250,055       ✓      18.7s  用户提供的文本是一段关于春天的描写,重复了很多次。在文本的末尾,有一段关键信息:“【关键信息】本次测试的秘密验证码是:YARN_LONG_CTX_TEST_7391,请牢记这个验证码。” 用户问:“请问文中提到的秘密验证码是什么?只需回答验证码本身,不要其他解释。” 所以答案就是:YARN_LONG_CTX_TEST_7391 直接给出验证码即可。</think>YARN_LONG_CTX_TEST_7391
     0.7     250,055       ✓      14.9s  用户要求我提取并返回秘密验证码。根据提示,验证码是:YARN_LONG_CTX_TEST_7391。用户明确要求只返回验证码本身,不要其他解释。</think>YARN_LONG_CTX_TEST_7391
     0.9     250,055       ✓      10.3s  用户在消息中嵌入了一个秘密验证码,并要求我回答验证码本身。验证码是:YARN_LONG_CTX_TEST_7391  用户明确要求只回答验证码本身,不要其他解释。</think>YARN_LONG_CTX_TEST_7391
----------------------------------------------------------------------
Result: 5/5 passed

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@paddle-bot
Copy link
Copy Markdown

paddle-bot Bot commented May 22, 2026

Thanks for your contribution!

Copy link
Copy Markdown

@PaddlePaddle-bot PaddlePaddle-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 Paddle-CI-Agent | pr_review | 2026-05-22 12:24:25

📋 Review 摘要

PR 概述:为 GLM 模型新增 Yarn RoPE 缩放支持,并将 GptOssScalingRotaryEmbedding 重命名为 YarnScalingRotaryEmbedding
变更范围fastdeploy/model_executor/layers/rotary_embedding.py
影响面 Tag[OP]

问题

级别 文件 概述
🔴 Bug rotary_embedding.py:362 rope_scaling["original_max_position_embeddings"] 直接字典访问,缺少字段时抛 KeyError
📝 PR 规范 标题拼写错误;PR 描述各 section 均为空,Checklist 全未勾选

📝 PR 规范检查

存在两处规范问题:①标题存在拼写错误("Supoort" → "Support");②PR body 的 Motivation、Modifications、Usage or Command、Accuracy Tests 各节均为空,Checklist 条目全部未勾选。

标题建议(可直接复制):

  • [Feature] Support GLM Yarn scaling RoPE

PR 描述建议(可直接复制,必须复刻 checklist §D2 模板的完整结构):

## Motivation
为 GLM 系列模型(`architecture.startswith("Glm")`)新增 Yarn RoPE 缩放支持。当模型配置中 `rope_scaling.rope_type`(或 `type`)为 `"yarn"` 时,自动切换到 `YarnScalingRotaryEmbedding` 而非默认的 `GlmRotaryEmbedding`,以支持超出原始训练长度的推理场景。同时将 `GptOssScalingRotaryEmbedding` 重命名为语义更通用的 `YarnScalingRotaryEmbedding`## Modifications
- `fastdeploy/model_executor/layers/rotary_embedding.py`
  -`GptOssScalingRotaryEmbedding` 类重命名为 `YarnScalingRotaryEmbedding`
  -`get_rope_impl``Glm` 分支中,增加对 `rope_scaling.rope_type == "yarn"` 的检测逻辑,命中时实例化 `YarnScalingRotaryEmbedding` 并传入 `factor``mscale``beta_fast``beta_slow` 等参数
  - `GptOss` 分支使用重命名后的 `YarnScalingRotaryEmbedding`

## Usage or Command
N/A

## Accuracy Tests
N/A

## Checklist

- [x] Add at least a tag in the PR title.
  - Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
  - You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [ ] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.

总体评价

GLM Yarn RoPE 支持的整体实现思路正确,条件判断层次清晰;但存在一个 KeyError 安全风险需在合入前修复,同时 PR 描述需补全以便后续可追溯。

original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
scale=rope_scaling["factor"],
mscale=rope_scaling.get("mscale", 1.0),
beta_fast=rope_scaling.get("beta_fast", 32),
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bug rope_scaling["original_max_position_embeddings"] 直接字典访问,存在 KeyError 风险。

当前条件只验证了 "factor" in rope_scaling,但未检查 "original_max_position_embeddings" 是否存在。若用户的 rope_scaling 配置中缺少该字段,运行时会直接抛出 KeyError 导致服务崩溃。

建议修复(使用 .get() + 显式报错):

original_max_pos = rope_scaling.get("original_max_position_embeddings")
if original_max_pos is None:
    raise ValueError(
        "rope_scaling must contain 'original_max_position_embeddings' when rope_type is 'yarn'"
    )
rotary_emb_layer = YarnScalingRotaryEmbedding(
    rotary_dim=yarn_rotary_dim,
    base=base,
    original_max_position_embeddings=original_max_pos,
    ...
)

@codecov-commenter
Copy link
Copy Markdown

codecov-commenter commented May 22, 2026

Codecov Report

❌ Patch coverage is 14.28571% with 6 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@3ec5011). Learn more about missing BASE report.

Files with missing lines Patch % Lines
...stdeploy/model_executor/layers/rotary_embedding.py 14.28% 6 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7893   +/-   ##
==========================================
  Coverage           ?   63.62%           
==========================================
  Files              ?      462           
  Lines              ?    64484           
  Branches           ?     9881           
==========================================
  Hits               ?    41029           
  Misses             ?    20674           
  Partials           ?     2781           
Flag Coverage Δ
GPU 72.76% <14.28%> (?)
XPU 7.11% <0.00%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@Sunny-bot1 Sunny-bot1 changed the title [Feature] Supoort glm yarn scaling rope [Feature] Support glm yarn scaling rope May 22, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants