
Commit 5f36055

feat: add GLM-4.7-Flash (glm4_moe_lite) model support
- Add glm4_moe_lite model implementation with MLA attention
- Add glm4_moe_lite_mtp for multi-token prediction support
- Refactor attention kernels to use dynamic batch size
- Add kernel configs for H200 GPU optimization
- Add BFCL evaluation scripts for function calling
1 parent 871ace6 commit 5f36055

37 files changed

Lines changed: 2230 additions & 34 deletions

lightllm/common/basemodel/attention/flashinfer/mla.py

Lines changed: 5 additions & 2 deletions
@@ -16,6 +16,8 @@ def __init__(self, model):
         self.qk_nope_head_dim = model.qk_nope_head_dim
         self.qk_rope_head_dim = model.qk_rope_head_dim
         self.kv_lora_rank = model.kv_lora_rank
+        # v_head_dim may differ from qk_nope_head_dim (e.g., GLM-4.7-Flash: v_head_dim=256, qk_nope_head_dim=192)
+        self.v_head_dim = getattr(model, "v_head_dim", self.qk_nope_head_dim)
         self.q_data_type = model.data_type
         self.kv_data_type = model.data_type
         self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id())
@@ -69,7 +71,7 @@ def init_state(self):
             num_qo_heads=self.backend.tp_q_head_num,
             num_kv_heads=self.backend.tp_q_head_num,
             head_dim_qk=self.backend.qk_nope_head_dim + self.backend.qk_rope_head_dim,
-            head_dim_vo=self.backend.qk_nope_head_dim,
+            head_dim_vo=self.backend.v_head_dim,  # Use v_head_dim, not qk_nope_head_dim
             q_data_type=self.backend.q_data_type,
             causal=True,
             sm_scale=self.backend.softmax_scale,
@@ -101,7 +103,8 @@ def _mla_prefill_att(
     ) -> torch.Tensor:
         self.backend: MlaFlashInferAttBackend = self.backend  # for typing
         k_nope, k_rope = k
-        o_tensor = alloc_func((q.shape[0], q.shape[1], k_nope.shape[2]), q.dtype, device="cuda")
+        # Output dimension is v_head_dim (from v.shape[-1]), not qk_nope_head_dim
+        o_tensor = alloc_func((q.shape[0], q.shape[1], v.shape[-1]), q.dtype, device="cuda")
         q_head_num = q.shape[1]
         k = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1)
         self.prefill_wrapper.run(q, k, v, out=o_tensor)
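The point of both edits is that the attention output inherits V's head dimension, which in GLM-4.7-Flash (v_head_dim=256) no longer matches qk_nope_head_dim (192). A minimal plain-PyTorch sketch of that shape relationship (illustrative dims, not the lightllm API):

import torch
from types import SimpleNamespace

# Stand-in for the model config; older MLA models may not define v_head_dim,
# so the backend falls back to qk_nope_head_dim, mirroring the getattr above.
model = SimpleNamespace(qk_nope_head_dim=192, v_head_dim=256)
v_head_dim = getattr(model, "v_head_dim", model.qk_nope_head_dim)

q_head_num, token_num = 8, 4
q = torch.randn(q_head_num, token_num, model.qk_nope_head_dim)
k = torch.randn(q_head_num, token_num, model.qk_nope_head_dim)
v = torch.randn(q_head_num, token_num, v_head_dim)

attn = torch.softmax(q @ k.transpose(-1, -2) * model.qk_nope_head_dim ** -0.5, dim=-1)
out = attn @ v
# The output takes its last dim from V, so the output buffer must be sized with
# v_head_dim (256), not qk_nope_head_dim (192).
assert out.shape == (q_head_num, token_num, v_head_dim)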

lightllm/common/basemodel/attention/triton/mla.py

Lines changed: 2 additions & 1 deletion
@@ -44,7 +44,8 @@ def _mla_prefill_att(
 
     qk_rope_head_dim = 64
     q_nope, q_rope = q[:, :, :-qk_rope_head_dim], q[:, :, -qk_rope_head_dim:]
-    o_tensor = alloc_func(q_nope.shape, dtype=q_nope.dtype, device=q.device)
+    # GLM-4.7-Flash: v_head_dim != qk_nope_head_dim
+    o_tensor = alloc_func((q_nope.shape[0], q_nope.shape[1], v.shape[-1]), dtype=q_nope.dtype, device=q.device)
     k_nope, k_rope = k
     assert att_control.mla_prefill
     softmax_scale = att_control.mla_prefill_dict["softmax_scale"]

lightllm/common/basemodel/basemodel.py

Lines changed: 1 addition & 0 deletions
@@ -1022,6 +1022,7 @@ def _gen_special_model_input(self, token_num: int):
             "Deepseek3MTPModel" in str(self.__class__)
             or "Qwen3MOEMTPModel" in str(self.__class__)
             or "MistralMTPModel" in str(self.__class__)
+            or "Glm4MoeLiteMTPModel" in str(self.__class__)
         )
         if is_mtp_draft_model:
             special_model_input["mtp_draft_input_hiddens"] = torch.randn(
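The draft-model check works by substring-matching the class name in str(self.__class__); a toy illustration (the class here is a stand-in, not the real lightllm model):

class Glm4MoeLiteMTPModel:  # stand-in class, only to show the string check
    pass

draft = Glm4MoeLiteMTPModel()
# str(cls) looks like "<class '__main__.Glm4MoeLiteMTPModel'>", so a substring test
# is enough to recognize the MTP draft model and allocate mtp_draft_input_hiddens.
assert "Glm4MoeLiteMTPModel" in str(draft.__class__)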

lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py

Lines changed: 7 additions & 4 deletions
@@ -81,8 +81,11 @@ def _fwd_kernel_calcu_index_and_block_seq(
     vsm_count,
     batch_size,
     BLOCK_N: tl.constexpr,
+    MAX_BATCH_SIZE: tl.constexpr,
 ):
-    b_seq_len = tl.load(b_seq_len + tl.arange(0, 2048), mask=tl.arange(0, 2048) < batch_size, other=0)
+    b_seq_len = tl.load(
+        b_seq_len + tl.arange(0, MAX_BATCH_SIZE), mask=tl.arange(0, MAX_BATCH_SIZE) < batch_size, other=0
+    )
     total_token_num = tl.sum(b_seq_len)
 
     block_seq = tl.cdiv(total_token_num, vsm_count * 4)
@@ -93,9 +96,9 @@ def _fwd_kernel_calcu_index_and_block_seq(
     cumsum_seq_len = tl.cumsum(block_seq_len)
     batch_start_index = cumsum_seq_len - block_seq_len
     tl.store(
-        mid_o_batch_start_index + tl.arange(0, 2048),
+        mid_o_batch_start_index + tl.arange(0, MAX_BATCH_SIZE),
         batch_start_index,
-        mask=tl.arange(0, 2048) < batch_size,
+        mask=tl.arange(0, MAX_BATCH_SIZE) < batch_size,
     )
     tl.store(mid_o_decode_att_block_seq, block_seq)
 
@@ -455,7 +458,6 @@ def gqa_token_decode_attention_flash_decoding_vsm(
     )
 
     if not hasattr(infer_state, "decode_att_block_seq"):
-        assert batch_size <= 2048
         decode_att_block_seq = torch.empty(
             [
                 1,
@@ -477,6 +479,7 @@ def gqa_token_decode_attention_flash_decoding_vsm(
         num_vsm,
         batch_size,
         BLOCK_N=run_config["BLOCK_N"],
+        MAX_BATCH_SIZE=triton.next_power_of_2(batch_size),
         num_warps=4,
     )
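Instead of a hard-coded tl.arange(0, 2048) and the matching assert batch_size <= 2048, the kernel now receives MAX_BATCH_SIZE, computed on the host as triton.next_power_of_2(batch_size). tl.arange needs a compile-time, power-of-two extent, so the padded range is masked back down to the real batch. A host-side sketch of the same idea (plain PyTorch, not the Triton kernel itself):

import torch

def next_power_of_2(n: int) -> int:
    # Same behaviour as triton.next_power_of_2 for n >= 1.
    return 1 << (n - 1).bit_length()

batch_size = 300
MAX_BATCH_SIZE = next_power_of_2(batch_size)  # 512 here, instead of a fixed 2048
b_seq_len = torch.randint(1, 100, (batch_size,))

# Padded load with a mask, mirroring tl.load(..., mask=arange < batch_size, other=0).
padded = torch.zeros(MAX_BATCH_SIZE, dtype=b_seq_len.dtype)
padded[:batch_size] = b_seq_len  # lanes >= batch_size stay 0, like other=0
total_token_num = int(padded.sum())
assert total_token_num == int(b_seq_len.sum())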

lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py

Lines changed: 1 addition & 1 deletion
@@ -227,7 +227,7 @@ def triton_grouped_topk(
 
     scores_buffer = torch.empty((token_num, total_expert_num), dtype=dtype, device="cuda")
     out_topk_weights = torch.empty((token_num, topk), dtype=torch.float32, device="cuda")
-    out_topk_ids = torch.empty((token_num, topk), dtype=torch.long, device="cuda")
+    out_topk_ids = torch.empty((token_num, topk), dtype=torch.int32, device="cuda")
 
     assert total_expert_num % num_expert_group == 0
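Switching out_topk_ids from torch.long to torch.int32 halves the index buffer; expert ids (at most a few hundred per model) easily fit in 32 bits. A quick check of the footprint (illustrative sizes, run on CPU):

import torch

token_num, topk = 4096, 8
ids_i64 = torch.empty((token_num, topk), dtype=torch.long)
ids_i32 = torch.empty((token_num, topk), dtype=torch.int32)
# 8 bytes vs 4 bytes per id.
assert ids_i64.element_size() == 8 and ids_i32.element_size() == 4
assert ids_i32.nbytes == ids_i64.nbytes // 2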

lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py

Lines changed: 5 additions & 3 deletions
@@ -196,10 +196,12 @@ def select_experts(
            scoring_func=scoring_func,
        )
    else:
-        group_score_topk_num = 1
-        # for deepseek v3
-        if topk_group == 4 and num_expert_group == 8 and top_k == 8:
+        if correction_bias is not None:
            group_score_topk_num = 2
+        elif topk_group == 4 and num_expert_group == 8 and top_k == 8:
+            group_score_topk_num = 2
+        else:
+            group_score_topk_num = 1
 
        topk_weights, topk_ids = triton_grouped_topk(
            hidden_states=hidden_states,
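The rewritten branch picks group_score_topk_num as follows: 2 whenever correction_bias is present, 2 for the DeepSeek-V3 shape (topk_group == 4, num_expert_group == 8, top_k == 8), and 1 otherwise. The same decision restated as a standalone sketch (a helper written here for illustration, not the lightllm function):

from typing import Optional
import torch

def pick_group_score_topk_num(
    correction_bias: Optional[torch.Tensor], topk_group: int, num_expert_group: int, top_k: int
) -> int:
    # Mirrors the branch in select_experts above.
    if correction_bias is not None:
        return 2
    elif topk_group == 4 and num_expert_group == 8 and top_k == 8:
        return 2
    else:
        return 1

assert pick_group_score_topk_num(None, 4, 8, 8) == 2             # deepseek-v3 case
assert pick_group_score_topk_num(torch.zeros(64), 2, 4, 6) == 2  # bias present
assert pick_group_score_topk_num(None, 2, 4, 6) == 1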

lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py

Lines changed: 10 additions & 3 deletions
@@ -67,7 +67,6 @@ def gqa_token_decode_attention_flash_decoding(
     )
 
     if not hasattr(infer_state, "decode_att_block_seq"):
-        assert batch_size <= 2048
         decode_att_block_seq = torch.empty(
             [
                 1,
@@ -89,6 +88,7 @@ def gqa_token_decode_attention_flash_decoding(
         vsm_count,
         batch_size,
         BLOCK_N=BLOCK_N,
+        MAX_BATCH_SIZE=triton.next_power_of_2(batch_size),
         num_warps=4,
     )
 
@@ -134,8 +134,11 @@ def _fwd_kernel_calcu_index_and_block_seq(
     num_sm,
     batch_size,
     BLOCK_N: tl.constexpr,
+    MAX_BATCH_SIZE: tl.constexpr,
 ):
-    b_seq_len = tl.load(b_seq_len_ptr + tl.arange(0, 2048), mask=tl.arange(0, 2048) < batch_size, other=0)
+    b_seq_len = tl.load(
+        b_seq_len_ptr + tl.arange(0, MAX_BATCH_SIZE), mask=tl.arange(0, MAX_BATCH_SIZE) < batch_size, other=0
+    )
     total_token_num = tl.sum(b_seq_len)
 
     block_seq = tl.cast(total_token_num / (num_sm * 4), dtype=tl.int32) + 1
@@ -144,6 +147,10 @@ def _fwd_kernel_calcu_index_and_block_seq(
     block_seq_len = tl.cdiv(b_seq_len, block_seq)
     cumsum_seq_len = tl.cumsum(block_seq_len)
     batch_start_index = cumsum_seq_len - block_seq_len
-    tl.store(mid_o_batch_start_index_ptr + tl.arange(0, 2048), batch_start_index, mask=tl.arange(0, 2048) < batch_size)
+    tl.store(
+        mid_o_batch_start_index_ptr + tl.arange(0, MAX_BATCH_SIZE),
+        batch_start_index,
+        mask=tl.arange(0, MAX_BATCH_SIZE) < batch_size,
+    )
     tl.store(mid_o_decode_att_block_seq_ptr, block_seq)
     return
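Both decode kernels split each request's sequence into block_seq-sized chunks and need a per-batch start offset into the mid_o buffers; cumsum(block_seq_len) - block_seq_len is simply the exclusive prefix sum. A host-side sketch of that bookkeeping (plain PyTorch, toy values assumed):

import torch

b_seq_len = torch.tensor([300, 17, 1024, 64])
num_sm = 4
total_token_num = int(b_seq_len.sum())
block_seq = total_token_num // (num_sm * 4) + 1           # same spirit as tl.cast(total / (num_sm * 4)) + 1
block_seq_len = (b_seq_len + block_seq - 1) // block_seq  # tl.cdiv
cumsum_seq_len = torch.cumsum(block_seq_len, dim=0)
batch_start_index = cumsum_seq_len - block_seq_len        # exclusive prefix sum: where each batch's blocks start
assert batch_start_index[0] == 0
assert torch.equal(batch_start_index[1:], cumsum_seq_len[:-1])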

lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py

Lines changed: 32 additions & 14 deletions
@@ -36,6 +36,9 @@ def _fwd_kernel_with_v(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_ROPE_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    BLOCK_V_DMODEL: tl.constexpr,
+    ACTUAL_DMODEL: tl.constexpr,
+    ACTUAL_V_DMODEL: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -53,8 +56,13 @@ def _fwd_kernel_with_v(
     # initialize offsets
     offs_n = tl.arange(0, BLOCK_N)
     offs_d = tl.arange(0, BLOCK_DMODEL)
+    offs_v_d = tl.arange(0, BLOCK_V_DMODEL)
     offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL)
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+
+    d_mask = offs_d < ACTUAL_DMODEL
+    v_d_mask = offs_v_d < ACTUAL_V_DMODEL
+
     off_q = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_bs + cur_head * stride_q_h + offs_d[None, :]
     off_q_rope = (
         (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_rope_bs
@@ -63,9 +71,10 @@ def _fwd_kernel_with_v(
     )
     off_k = offs_n[None, :] * stride_k_bs + cur_k_head * stride_k_h + offs_d[:, None]
     off_k_rope = offs_n[None, :] * stride_k_rope_bs + offs_rope_d[:, None]
-    off_v = offs_n[:, None] * stride_vbs + cur_k_head * stride_vh + offs_d[None, :]
+    off_v = offs_n[:, None] * stride_vbs + cur_k_head * stride_vh + offs_v_d[None, :]
 
-    q = tl.load(Q_nope + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
+    q_mask = (offs_m[:, None] < cur_batch_seq_len) & d_mask[None, :]
+    q = tl.load(Q_nope + off_q, mask=q_mask, other=0.0)
     q_rope = tl.load(Q_rope + off_q_rope, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
 
     k_ptrs = K_nope + off_k
@@ -75,22 +84,24 @@ def _fwd_kernel_with_v(
     # initialize pointer to m and l
     m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
     l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, BLOCK_V_DMODEL], dtype=tl.float32)
 
     block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
     block_end_loc = tl.minimum((start_m + 1) * BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len)
 
     for start_n in range(0, block_mask * block_end_loc, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
+        k_seq_mask = (start_n + offs_n[None, :]) < block_end_loc
+        k_mask = k_seq_mask & d_mask[:, None]
         k = tl.load(
             k_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_bs,
-            mask=(start_n + offs_n[None, :]) < block_end_loc,
+            mask=k_mask,
             other=0.0,
         )
         k_rope = tl.load(
             k_rope_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_rope_bs,
-            mask=(start_n + offs_n[None, :]) < block_end_loc,
+            mask=k_seq_mask,
             other=0.0,
         )
 
@@ -112,9 +123,11 @@ def _fwd_kernel_with_v(
         # -- update output accumulator --
         acc = acc * alpha[:, None]
         # update acc
+        v_seq_mask = (start_n + offs_n[:, None]) < block_end_loc
+        v_mask = v_seq_mask & v_d_mask[None, :]
         v = tl.load(
             v_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_vbs,
-            mask=(start_n + offs_n[:, None]) < block_end_loc,
+            mask=v_mask,
             other=0.0,
         )
         p = p.to(v.dtype)
@@ -124,9 +137,10 @@ def _fwd_kernel_with_v(
 
     acc = acc / l_i[:, None]
     # initialize pointers to output
-    off_o = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :]
+    off_o = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_v_d[None, :]
     out_ptrs = Out + off_o
-    tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
+    o_mask = (offs_m[:, None] < cur_batch_seq_len) & v_d_mask[None, :]
+    tl.store(out_ptrs, acc, mask=o_mask)
     return
 
 
@@ -149,13 +163,14 @@ def context_attention_fwd_with_v(
     BLOCK = 128 if not is_tesla() else 64
     q_nope_dim = q_nope.shape[-1]
     q_rope_dim = q_rope.shape[-1]
+    v_dim = v.shape[-1]
     assert q_nope_dim == k_nope.shape[-1]
     assert q_rope_dim == k_rope.shape[-1]
-    assert q_nope_dim in {16, 32, 64, 128, 256, 512}
-    assert q_rope_dim in {16, 32, 64, 128, 256}
-    assert q_nope_dim == v.shape[-1]
 
-    if q_nope_dim >= 512:
+    q_nope_dim_padded = triton.next_power_of_2(q_nope_dim)
+    v_dim_padded = triton.next_power_of_2(v_dim)
+
+    if q_nope_dim_padded >= 512 or v_dim_padded >= 512:
        BLOCK = 64 if not is_tesla() else 32
     else:
        BLOCK = 128 if not is_tesla() else 64
@@ -167,7 +182,7 @@ def context_attention_fwd_with_v(
     batch, head = b_seq_len.shape[0], q_nope.shape[1]
 
     grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head,
-    num_warps = 4 if q_nope_dim <= 64 else 8
+    num_warps = 4 if q_nope_dim_padded <= 64 else 8
 
     _fwd_kernel_with_v[grid](
         q_nope,
@@ -194,9 +209,12 @@ def context_attention_fwd_with_v(
         o.stride(1),
         b_prompt_cache_len=b_prompt_cache_len,
         BLOCK_M=BLOCK,
-        BLOCK_DMODEL=q_nope_dim,
+        BLOCK_DMODEL=q_nope_dim_padded,
         BLOCK_ROPE_DMODEL=q_rope_dim,
         BLOCK_N=BLOCK,
+        BLOCK_V_DMODEL=v_dim_padded,
+        ACTUAL_DMODEL=q_nope_dim,
+        ACTUAL_V_DMODEL=v_dim,
         num_warps=num_warps,
         num_stages=1,
     )
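The kernel changes follow a standard Triton pattern: tl.arange extents must be powers of two, so a head dim like 192 is padded up to BLOCK_DMODEL = next_power_of_2(192) = 256 and every load/store is masked with offs < ACTUAL_DMODEL so the padded lanes contribute zeros. A plain-PyTorch sketch of the pad-and-mask idea (illustrative sizes, not the kernel itself):

import torch

def next_power_of_2(n: int) -> int:
    # Same behaviour as triton.next_power_of_2 for n >= 1.
    return 1 << (n - 1).bit_length()

actual_dmodel = 192                             # e.g. GLM-4.7-Flash qk_nope_head_dim
block_dmodel = next_power_of_2(actual_dmodel)   # 256
offs_d = torch.arange(block_dmodel)
d_mask = offs_d < actual_dmodel

row = torch.randn(actual_dmodel)
padded = torch.zeros(block_dmodel)
padded[d_mask] = row                            # masked load: padded lanes stay 0.0
# Zero padding is harmless for the q @ k dot product: the extra lanes add 0 to every score.
assert torch.allclose(padded[:actual_dmodel], row)
assert float(padded[actual_dmodel:].abs().sum()) == 0.0
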
lightllm/common/quantization/no_quant.py

Lines changed: 2 additions & 2 deletions
@@ -28,8 +28,8 @@ def apply(
         device = input_tensor.device
         if use_custom_tensor_mananger:
             out = g_cache_manager.alloc_tensor(shape, dtype, device=device)
-        else:
-            out = torch.empty(shape, dtype=dtype, device=device)
+        else:
+            out = torch.empty(shape, dtype=dtype, device=device)
         if bias is None:
             return torch.mm(input_tensor, weight, out=out)
         return torch.addmm(bias, input_tensor, weight, out=out)
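For reference, the two return paths differ only in whether a bias is folded in: torch.mm writes input @ weight into out, while torch.addmm writes bias + input @ weight. A tiny check with toy shapes:

import torch

inp, weight = torch.randn(3, 4), torch.randn(4, 5)
bias, out = torch.randn(5), torch.empty(3, 5)

torch.mm(inp, weight, out=out)
assert torch.allclose(out, inp @ weight)

torch.addmm(bias, inp, weight, out=out)
assert torch.allclose(out, inp @ weight + bias)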
