Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,11 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids,
msg_sed.mtype = 1;
msg_sed.meta[0] = not_need_stop.data<bool>()[0] ? inference_msg_id_from_env
: -inference_msg_id_from_env;
msg_sed.meta[1] = message_flag;
msg_sed.meta[2] = bsz;
// Pack message_flag (low 8 bits) and max_num_logprobs (high 16 bits) into
// meta[1]. Receiver unpacks both to avoid reading unused topk slots.
int max_num_logprobs = logprob_token_ids.shape()[1];
msg_sed.meta[1] = message_flag | (max_num_logprobs << 8);
msg_sed.meta[2] = bsz;
for (int i = 0; i < bsz; i++) {
int cur_token_num;
if (seq_lens_decoder_data[i] < prompt_lens_data[i] ||
Expand All @@ -139,29 +141,24 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids,
auto* cur_batch_msg_sed = &msg_sed.mtext[i];
int token_offset = cu_batch_token_offset_data[i];
for (int j = 0; j < cur_token_num; j++) {
// Use SPEC_LOGPROB_K+1 as stride (fixed struct layout), but only write
// max_num_logprobs columns to avoid filling unused topk slots.
auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)];
auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)];
if (j == 0) {
// first token has full logprobs
for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
for (int k = 0; k < max_num_logprobs; k++) {
if (k == 0) {
cur_tokens[k] =
(int)sampled_token_ids_data[i * max_draft_tokens + j];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
k];
} else if (k < max_num_logprobs) {
// only for first token
cur_tokens[k] =
(int)logprob_token_ids_data[(token_offset + j) *
(SPEC_LOGPROB_K + 1) +
k];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
k];
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
} else {
cur_tokens[k] = -1;
cur_scores[k] = 0.0;
cur_tokens[k] = (int)
logprob_token_ids_data[(token_offset + j) * max_num_logprobs +
k];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
}
}
cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,11 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,

int bsz = msg_rcv.meta[2];
output_tokens_data[0] = (int64_t)msg_rcv.meta[0];
// Unpack message_flag (low 8 bits) and actual_topk (high 16 bits) from
// meta[1]. Keep packed value; Python unpacks message_flag and actual_topk.
output_tokens_data[1] = (int64_t)msg_rcv.meta[1];
output_tokens_data[2] = (int64_t)msg_rcv.meta[2];
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ 疑问 real_k 参数是否已成死代码?

本次修改将循环条件从 k < real_k + 1 改为 k < actual_topkreal_k 参数不再出现在循环体中。但 Python 侧在调用此算子时仍传入 K=20token_processor.py 中的常量)作为 real_k 属性。

请确认 real_k 是否还被用于输出 tensor 的预分配或其他地方;若已完全废弃,建议清理该参数以避免后续维护误解。

int actual_topk = (msg_rcv.meta[1] >> 8) & 0xFFFF;

int output_tokens_offset = 3 + SPEC_LOGPROB_MAX_BSZ;
for (int i = 0; i < bsz; i++) {
Expand All @@ -89,7 +92,7 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,
output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1));
auto* cur_batch_msg_rcv = &msg_rcv.mtext[i];
for (int j = 0; j < cur_token_num; j++) {
for (int k = 0; k < real_k + 1; k++) {
for (int k = 0; k < actual_topk; k++) {
cur_output_token[j * (SPEC_LOGPROB_K + 1) + k] =
(int64_t)cur_batch_msg_rcv->tokens[j * (SPEC_LOGPROB_K + 1) + k];
cur_output_score[j * (SPEC_LOGPROB_K + 1) + k] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,11 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
msg_sed.mtype = 1;
msg_sed.meta[0] = not_need_stop.data<bool>()[0] ? inference_msg_id_from_env
: -inference_msg_id_from_env;
msg_sed.meta[1] = message_flag;
msg_sed.meta[2] = bsz;
// Pack message_flag (low 8 bits) and max_num_logprobs (high 16 bits) into
// meta[1]. Receiver unpacks both to avoid reading unused topk slots.
int max_num_logprobs = logprob_token_ids.shape()[1];
msg_sed.meta[1] = message_flag | (max_num_logprobs << 8);
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 max_num_logprobs 打包进 16 位高位存在截断风险

max_num_logprobs > 65535 时(例如用户请求全词表 logprobs,vocab_size > 65K),接收端 (meta[1] >> 8) & 0xFFFF 会静默截断,导致 actual_topk 错误,进而越界读写或丢失数据。

旧代码中 speculative decoding 的 max_logprobs 硬限为 20,本次 gpu_model_runner.py 移除该上限后,此场景已可触发。

建议在此处加防御断言:

PD_CHECK(max_num_logprobs <= 0xFFFF,
         "max_num_logprobs %d exceeds 16-bit limit", max_num_logprobs);

msg_sed.meta[2] = bsz;
for (int i = 0; i < bsz; i++) {
int cur_token_num;
if (seq_lens_decoder_data[i] < prompt_lens_data[i]) {
Expand All @@ -139,24 +141,20 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
auto* cur_batch_msg_sed = &msg_sed.mtext[i];
int token_offset = cu_batch_token_offset_data[i];
for (int j = 0; j < cur_token_num; j++) {
// Use SPEC_LOGPROB_K+1 as stride (fixed struct layout), but only write
// max_num_logprobs columns to avoid filling unused topk slots.
auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)];
auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)];
for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
for (int k = 0; k < max_num_logprobs; k++) {
if (k == 0) {
cur_tokens[k] = (int)sampled_token_ids_data[i * max_draft_tokens + j];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
k];
} else if (k < max_num_logprobs) {
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
} else {
cur_tokens[k] = (int)
logprob_token_ids_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
k];
logprob_token_ids_data[(token_offset + j) * max_num_logprobs + k];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
k];
} else {
cur_tokens[k] = -1;
cur_scores[k] = 0.0;
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
}
}
cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j];
Expand Down
36 changes: 23 additions & 13 deletions fastdeploy/output/token_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,12 +735,15 @@ def _process_batch_draft_tokens(self, mtype, batch, accept_num, tokens, scores,
metrics=None,
)

token_ids = tokens[i][:, 0].tolist()[: accept_num[i]]
tokens_i = tokens[i].tolist()
scores_i = scores[i].tolist()
ranks_i = ranks[i].tolist()
token_ids = [row[0] for row in tokens_i[: accept_num[i]]]
for batch_token_index in range(len(token_ids)):
result.outputs.logprob = float(scores[i, batch_token_index, 0])
topk_token_ids = tokens[i, batch_token_index, :].tolist()
topk_logprobs = scores[i, batch_token_index, :].tolist()
sampled_rank = ranks[i, batch_token_index].item()
result.outputs.logprob = scores_i[batch_token_index][0]
topk_token_ids = tokens_i[batch_token_index]
topk_logprobs = scores_i[batch_token_index]
sampled_rank = ranks_i[batch_token_index]

if result.outputs.draft_top_logprobs is None:
result.outputs.draft_top_logprobs = LogprobsLists(
Expand All @@ -767,16 +770,19 @@ def _process_batch_output(self):
mtype = 3
if self.cfg.speculative_config.method:
if self.use_logprobs:
mtype = int(self.output_tokens[1, 0].item())
# meta[1] packs message_flag (low 8 bits) and actual_topk (high 16 bits).
packed_meta1 = int(self.output_tokens[1, 0].item())
mtype = packed_meta1 & 0xFF
actual_topk = (packed_meta1 >> 8) & 0xFFFF
batch = self.output_tokens[2, 0]
accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]]
tokens = tokens[3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1)].reshape(
[batch, MAX_DRAFT_TOKENS, K + 1]
)
)[:, :, :actual_topk]
scores = (
self.output_scores[: batch * MAX_DRAFT_TOKENS * (K + 1)]
.numpy()
.reshape([batch, MAX_DRAFT_TOKENS, K + 1])
.reshape([batch, MAX_DRAFT_TOKENS, K + 1])[:, :, :actual_topk]
)
ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS])

Expand All @@ -785,6 +791,10 @@ def _process_batch_output(self):
batch_result = self._process_batch_draft_tokens(mtype, batch, accept_num, tokens, scores, ranks)
self.postprocess(batch_result, mtype)
return
# Pre-convert full arrays to Python lists once for MTP target token path.
tokens_lists = tokens.tolist()
scores_lists = scores.tolist()
ranks_list = ranks.tolist()
else:
batch = self.output_tokens[1]
accept_num = tokens[2 : batch + 2]
Expand Down Expand Up @@ -852,7 +862,7 @@ def _process_batch_output(self):
)
token_ids = [RECOVERY_STOP_SIGNAL]
elif self.use_logprobs:
token_ids = tokens[i][:, 0].tolist()[: accept_num[i]]
token_ids = [row[0] for row in tokens_lists[i][: accept_num[i]]]
else:
token_ids = tokens[
2
Expand Down Expand Up @@ -984,10 +994,10 @@ def _process_batch_output(self):
task.output_token_ids.append(token_id)
if self.use_logprobs:
if self.cfg.speculative_config.method:
result.outputs.logprob = float(scores[i, batch_token_index, 0])
topk_token_ids = tokens[i, batch_token_index, :].tolist()
topk_logprobs = scores[i, batch_token_index, :].tolist()
sampled_rank = ranks[i, batch_token_index].item()
result.outputs.logprob = scores_lists[i][batch_token_index][0]
topk_token_ids = tokens_lists[i][batch_token_index]
topk_logprobs = scores_lists[i][batch_token_index]
sampled_rank = ranks_list[i][batch_token_index]
else:
# Use pre-converted lists (batch .tolist() done before the loop).
result.outputs.logprob = scores_lists[i][0]
Expand Down
14 changes: 5 additions & 9 deletions fastdeploy/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,15 +1275,11 @@ def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_p
req.sampling_params.top_p_normalized_logprobs and req.sampling_params.top_p != 1.0 for req in logprobs_reqs
)
if logprobs_reqs:
self.max_logprobs = (
max(
[
self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs
for req in logprobs_reqs
]
)
if not self.speculative_decoding
else 20
self.max_logprobs = max(
[
self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs
for req in logprobs_reqs
]
)
elif self.enable_logprob:
self.max_logprobs = None if not self.speculative_decoding else 0
Expand Down
Loading