Commit b2641ae

Author: niushengxiao
Commit message: feat: fp8kv support
1 parent f5ee4c3 commit b2641ae

17 files changed: 1952 additions & 1592 deletions

File tree

docs/CN/source/index.rst

Lines changed: 1 addition & 0 deletions

@@ -49,6 +49,7 @@ Lightllm integrates the strengths of many open-source solutions, including but not limited to FasterTran
     :caption: Deployment Tutorials

     DeepSeek R1 Deployment <tutorial/deepseek_deployment>
+    FP8 KV Quantization and Calibration <tutorial/fp8_kv_quantization>
     Multi-Level Cache Deployment <tutorial/multi_level_cache_deployment>
     Multimodal Deployment <tutorial/multimodal>
     Reward Model Deployment <tutorial/reward_model>
Lines changed: 149 additions & 0 deletions

@@ -0,0 +1,149 @@

.. _tutorial/fp8_kv_quantization_cn:

FP8 KV Quantization and Calibration Guide
=========================================

This chapter walks through the complete FP8 KV quantization workflow in LightLLM, including:

- Exporting a calibration file (``--export_fp8kv_calibration``)
- Running inference with the calibration file (``fp8kv``)
- Quantization-granularity differences between the FA3 and FlashInfer backends
- Common errors and troubleshooting tips

Overview
--------

LightLLM's FP8 KV quantization uses an offline calibration scheme:

1. First run export mode, which tracks the maximum absolute value of the KV entries and writes ``kv_cache_calib.json``.
2. Then load that file in inference mode, quantizing KV by scale into ``float8_e4m3fn`` storage.
Backend and Quantization Granularity
------------------------------------

Current behavior:

- ``fa3``: uses ``per_head`` (an independent scale per head)
- ``flashinfer``: uses ``per_tensor`` (one scalar scale each for K and V)

Calibration files are therefore backend-specific:

- A ``per_head`` calibration file produced with ``fa3`` is meant for ``fa3`` inference.
- A ``per_tensor`` calibration file produced with ``flashinfer`` is meant for ``flashinfer`` inference.

Mixing calibration files exported by different backends is not recommended.
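In practice the difference between the two granularities is just the shape of the scale table. The layout below is a hypothetical illustration of what each granularity implies for a 32-layer, 8-head model, not the exact structure LightLLM stores internally:

```python
num_layers, num_head = 32, 8  # hypothetical model dimensions

# per_head (fa3): one scale per layer, per K/V tensor, per head
per_head_scales = [[[1.0] * num_head for _kv in range(2)] for _l in range(num_layers)]

# per_tensor (flashinfer): one scalar per layer for K and one for V
per_tensor_scales = [[1.0, 1.0] for _l in range(num_layers)]

# per_head carries num_head times more entries per layer than per_tensor,
# which is why the two file formats cannot be swapped between backends.
print(len(per_head_scales[0][0]), len(per_tensor_scales[0]))
```

A ``per_head`` table cannot be collapsed into a ``per_tensor`` one (or vice versa) without re-running calibration, hence the no-mixing rule above.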
Step 1: Export the Calibration File
-----------------------------------

Export-mode example (FA3):

.. code-block:: console

    $ python -m lightllm.server.api_server \
        --model_dir /path/to/model \
        --export_fp8kv_calibration \
        --llm_prefill_att_backend fa3 \
        --llm_decode_att_backend fa3 \
        --disable_cudagraph

Export-mode example (FlashInfer):

.. code-block:: console

    $ python -m lightllm.server.api_server \
        --model_dir /path/to/model \
        --export_fp8kv_calibration \
        --llm_prefill_att_backend flashinfer \
        --llm_decode_att_backend flashinfer \
        --disable_cudagraph

Notes:

- With ``--export_fp8kv_calibration`` set, KV statistics are collected while the server runs.
- Once calibration completes, ``kv_cache_calib.json`` is written to the current working directory.
- Export mode requires ``--disable_cudagraph``, and ``--llm_kv_type`` must remain ``None``.
- The repository ships calibration files for common models under ``test/advanced_config/``; use them directly or as references.

Calibrating with random data via benchmark_qps.py
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Besides live traffic, you can use ``test/benchmark/service/benchmark_qps.py`` to generate random requests for calibration.

- By default, a calibration result is written after roughly 4000 accumulated inferences.
- In practice, run the following command twice to cover the statistics range more reliably.

Example command:

.. code-block:: console

    $ python test/benchmark/service/benchmark_qps.py --url http://127.0.0.1:8000/generate_stream --tokenizer_path ../Qwen3-30B-A3B --input_len 1000 --output_len 2000 --input_qps 10 --input_num 200 --range_ratio 0.9
Step 2: Start FP8 Inference with the Calibration File
-----------------------------------------------------

Inference-mode example (FA3):

.. code-block:: console

    $ python -m lightllm.server.api_server \
        --model_dir /path/to/model \
        --llm_kv_type fp8kv \
        --llm_prefill_att_backend fa3 \
        --llm_decode_att_backend fa3 \
        --kv_quant_calibration_config_path /path/to/kv_cache_calib.json

Inference-mode example (FlashInfer):

.. code-block:: console

    $ python -m lightllm.server.api_server \
        --model_dir /path/to/model \
        --llm_kv_type fp8kv \
        --llm_prefill_att_backend flashinfer \
        --llm_decode_att_backend flashinfer \
        --kv_quant_calibration_config_path /path/to/kv_cache_calib.json

Notes:

- ``fp8kv`` mode requires ``--kv_quant_calibration_config_path``.
- Keep the attention backend at inference time consistent with the one used for calibration export.

Calibration File Format
-----------------------

The main fields of the exported ``kv_cache_calib.json`` are:

- ``quant_type``: ``per_head`` or ``per_tensor``
- ``num_layers``: number of layers
- ``num_head``: total number of heads
- ``scales_shape``: shape of the scale tensor
- ``scales``: the actual scale values
- ``qmin`` / ``qmax``: FP8 range parameters

When the calibration file is loaded, the model architecture, layer count, head count, and quantization type are all validated for a match.

Multi-GPU Note
--------------

In multi-GPU (TP) setups, the system automatically slices out the scales for the heads local to the current rank.
You still only need to provide a single full ``kv_cache_calib.json``.

Common Issues
-------------

1. Startup error saying ``--kv_quant_calibration_config_path`` is required

   You passed ``--llm_kv_type fp8kv`` without supplying a calibration file path.

2. Startup error requiring ``--disable_cudagraph``

   You passed ``--export_fp8kv_calibration``; this mode must run with cudagraph disabled.

3. ``quant_type not match`` error

   Usually the backend and the calibration-file type disagree, e.g. running ``flashinfer`` with a ``per_head`` file.

4. Abnormal output quality after switching backends

   Re-export the calibration file for the target backend; do not reuse files across backends.

docs/EN/source/index.rst

Lines changed: 1 addition & 0 deletions

@@ -48,6 +48,7 @@ Documentation List
     :caption: Deployment Tutorials

     DeepSeek R1 Deployment <tutorial/deepseek_deployment>
+    FP8 KV Quantization and Calibration <tutorial/fp8_kv_quantization>
     Multi-Level Cache Deployment <tutorial/multi_level_cache_deployment>
     Multimodal Deployment <tutorial/multimodal>
     Reward Model Deployment <tutorial/reward_model>
Lines changed: 149 additions & 0 deletions

@@ -0,0 +1,149 @@

.. _tutorial/fp8_kv_quantization_en:

FP8 KV Quantization and Calibration Guide
=========================================

This chapter describes the end-to-end FP8 KV quantization workflow in LightLLM, including:

- Exporting calibration data (``--export_fp8kv_calibration``)
- Running inference with calibration data (``fp8kv``)
- Quantization-granularity differences between FA3 and FlashInfer
- Common errors and troubleshooting

Overview
--------

LightLLM uses an offline calibration flow for FP8 KV quantization:

1. Run export mode to collect KV statistics and produce ``kv_cache_calib.json``.
2. Run inference mode with that file, quantizing KV into ``float8_e4m3fn`` storage.

Backend and Quantization Granularity
------------------------------------

Current behavior:

- ``fa3``: ``per_head`` scales (an independent scale per head)
- ``flashinfer``: ``per_tensor`` scales (one scalar for K and one scalar for V)

Calibration files are backend-dependent:

- ``per_head`` files exported with ``fa3`` should be used with ``fa3`` inference.
- ``per_tensor`` files exported with ``flashinfer`` should be used with ``flashinfer`` inference.

Avoid mixing calibration files across different backends.

Step 1: Export the Calibration File
-----------------------------------

Export-mode example (FA3):

.. code-block:: console

    $ python -m lightllm.server.api_server \
        --model_dir /path/to/model \
        --export_fp8kv_calibration \
        --llm_prefill_att_backend fa3 \
        --llm_decode_att_backend fa3 \
        --disable_cudagraph

Export-mode example (FlashInfer):

.. code-block:: console

    $ python -m lightllm.server.api_server \
        --model_dir /path/to/model \
        --export_fp8kv_calibration \
        --llm_prefill_att_backend flashinfer \
        --llm_decode_att_backend flashinfer \
        --disable_cudagraph

Notes:

- Setting ``--export_fp8kv_calibration`` collects KV statistics at runtime.
- After calibration completes, ``kv_cache_calib.json`` is written to the current working directory.
- Export mode requires ``--disable_cudagraph``, and ``--llm_kv_type`` should remain ``None``.
- The repository already provides calibration files for common models under ``test/advanced_config/``; they can be used directly or as references.

Using benchmark_qps.py for random-data calibration
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Besides online traffic, you can use ``test/benchmark/service/benchmark_qps.py`` to generate random requests for calibration.

- By default, one calibration result is exported after around 4000 inferences have accumulated.
- In practice, you can run the following command twice to improve coverage stability.

Example command:

.. code-block:: console

    $ python test/benchmark/service/benchmark_qps.py --url http://127.0.0.1:8000/generate_stream --tokenizer_path ../Qwen3-30B-A3B --input_len 1000 --output_len 2000 --input_qps 10 --input_num 200 --range_ratio 0.9

Step 2: Start FP8 Inference with Calibration
--------------------------------------------

Inference-mode example (FA3):

.. code-block:: console

    $ python -m lightllm.server.api_server \
        --model_dir /path/to/model \
        --llm_kv_type fp8kv \
        --llm_prefill_att_backend fa3 \
        --llm_decode_att_backend fa3 \
        --kv_quant_calibration_config_path /path/to/kv_cache_calib.json

Inference-mode example (FlashInfer):

.. code-block:: console

    $ python -m lightllm.server.api_server \
        --model_dir /path/to/model \
        --llm_kv_type fp8kv \
        --llm_prefill_att_backend flashinfer \
        --llm_decode_att_backend flashinfer \
        --kv_quant_calibration_config_path /path/to/kv_cache_calib.json

Notes:

- ``fp8kv`` requires ``--kv_quant_calibration_config_path``.
- Keep the inference backend consistent with the backend used during calibration export.

Calibration File Schema
-----------------------

Key fields in ``kv_cache_calib.json``:

- ``quant_type``: ``per_head`` or ``per_tensor``
- ``num_layers``: number of layers
- ``num_head``: total number of heads
- ``scales_shape``: shape of the scale tensor
- ``scales``: actual scale values
- ``qmin`` / ``qmax``: FP8 numeric range parameters

At load time, LightLLM validates the model architecture, layer count, head count, and quantization type.
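A loader honoring this schema can be sketched as follows. The field names come from the list above, but the function and its validation logic are a simplified illustration, not LightLLM's actual loader:

```python
import json
import os
import tempfile

def load_kv_calibration(path, model_layers, model_heads, backend_quant_type):
    """Load kv_cache_calib.json and check it matches the running model
    (hypothetical helper mirroring the validation described above)."""
    with open(path) as f:
        calib = json.load(f)
    if calib["quant_type"] != backend_quant_type:
        raise ValueError(f"quant_type not match: {calib['quant_type']!r}")
    if calib["num_layers"] != model_layers or calib["num_head"] != model_heads:
        raise ValueError("calibration file does not match model shape")
    return calib

# Minimal hand-written example file for a toy 2-layer, 4-head model.
demo = {
    "quant_type": "per_head",
    "num_layers": 2,
    "num_head": 4,
    "scales_shape": [2, 2, 4],
    "scales": [[[0.01] * 4, [0.02] * 4], [[0.01] * 4, [0.02] * 4]],
    "qmin": -448.0,
    "qmax": 448.0,
}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(demo, f)
    path = f.name
calib = load_kv_calibration(path, model_layers=2, model_heads=4,
                            backend_quant_type="per_head")
print(calib["quant_type"], calib["qmax"])
os.remove(path)
```

Passing ``backend_quant_type="per_tensor"`` against this file would raise the ``quant_type not match`` error described in the Common Issues section.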
Multi-GPU Note
--------------

In multi-GPU (TP) setups, LightLLM automatically slices the global scales down to each rank's local heads.
You only need to provide one full ``kv_cache_calib.json`` file.
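The slicing rule can be illustrated with a minimal sketch, assuming heads are partitioned contiguously and evenly across ranks (a common TP layout; LightLLM's actual partitioning code may differ):

```python
def local_scales(global_scales, tp_rank, tp_world_size):
    # Slice one layer's per-head scale row down to the heads owned by this rank.
    num_head = len(global_scales)
    assert num_head % tp_world_size == 0, "head count must divide evenly across ranks"
    n = num_head // tp_world_size
    return global_scales[tp_rank * n:(tp_rank + 1) * n]

# 8 global heads split across 4 TP ranks -> 2 heads (and 2 scales) per rank
scales = [0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80]
print(local_scales(scales, tp_rank=1, tp_world_size=4))  # [0.3, 0.4]
```

Because each rank derives its slice from the same full table, the single exported file stays valid across different TP degrees as long as the head count divides evenly.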
Common Issues
-------------

1. Error saying ``--kv_quant_calibration_config_path`` is required

   You are using ``--llm_kv_type fp8kv`` without a calibration file path.

2. Error saying ``--disable_cudagraph`` is required

   You are using ``--export_fp8kv_calibration``; this mode requires cudagraph to be disabled.

3. ``quant_type not match`` error

   Usually caused by a backend/file mismatch (for example, using a ``per_head`` file with ``flashinfer``).

4. Abnormal quality after switching backends

   Re-export calibration with the target backend instead of reusing files across backends.

lightllm/common/basemodel/attention/create_utils.py

Lines changed: 4 additions & 0 deletions

@@ -36,6 +36,10 @@
         "fa3": Fp8Fa3AttBackend,
         "flashinfer": Fp8FlashInferAttBackend,
     },
+    "fp8kv": {
+        "fa3": Fp8Fa3AttBackend,
+        "flashinfer": Fp8FlashInferAttBackend,
+    },
 }

 mla_data_type_to_backend = {

lightllm/common/basemodel/attention/fa3/fp8.py

Lines changed: 8 additions & 6 deletions

@@ -89,19 +89,21 @@ def _fp8_prefill_att(
     ) -> torch.Tensor:
         self.backend: Fp8Fa3AttBackend = self.backend  # for typing

+        q_head_num = q.shape[1]
+        q_head_dim = q.shape[2]
+        k_head_num = k.shape[1]
         q, q_scale = q_per_head_fp8_quant(
-            q,
+            q.reshape(q.shape[0], k_head_num, -1),
             self.infer_state.b_seq_len,
             self.cu_seqlens_q,
-            self.mid_token_batch_ids,
+            token_batch_ids=self.mid_token_batch_ids,
         )
-        k_head_num = k.shape[1]
         k_head_dim = k.shape[2]
         cache_k = k.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn)
         cache_v = v.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn)
         layer_index = self.backend._find_layer_index(k=cache_k, v=cache_v, att_state=self)
         o = flash_attn_with_kvcache(
-            q=q,
+            q=q.reshape(-1, q_head_num, q_head_dim),
             k_cache=cache_k,
             v_cache=cache_v,
             page_table=self.page_table,
@@ -200,9 +202,9 @@ def _fp8_decode_att(
         layer_index = self.backend._find_layer_index(k=cache_k, v=cache_v, att_state=self)

         q_head_num = q.shape[1]
-        q, q_scale = scaled_fp8_quant(q.view(q.shape[0] * k_head_num, -1), use_per_token_if_dynamic=True)
+        q, q_scale = scaled_fp8_quant(q.reshape(q.shape[0] * k_head_num, -1), use_per_token_if_dynamic=True)
         o = flash_attn_with_kvcache(
-            q=q.view(-1, q_head_num, k_head_dim),
+            q=q.reshape(-1, q_head_num, k_head_dim),
             k_cache=cache_k,
             v_cache=cache_v,
             page_table=self.page_table,

lightllm/common/basemodel/attention/flashinfer/fp8.py

Lines changed: 2 additions & 2 deletions

@@ -20,7 +20,7 @@ def create_att_decode_state(self, infer_state) -> "Fp8FlashInferDecodeAttState":

 @dataclasses.dataclass
 class Fp8FlashInferPrefillAttState(FlashInferPrefillAttState):
-    offline_scales: torch.Tensor = None
+    offline_scales: list = None

     def init_state(self):
         super().init_state()
@@ -68,7 +68,7 @@ def _fp8_prefill_att(

 @dataclasses.dataclass
 class Fp8FlashInferDecodeAttState(FlashInferDecodeAttState):
-    offline_scales: torch.Tensor = None
+    offline_scales: list = None

     def init_state(self):
         super().init_state()
