Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,11 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids,
msg_sed.mtype = 1;
msg_sed.meta[0] = not_need_stop.data<bool>()[0] ? inference_msg_id_from_env
: -inference_msg_id_from_env;
msg_sed.meta[1] = message_flag;
msg_sed.meta[2] = bsz;
// Pack message_flag (low 8 bits) and max_num_logprobs (high 16 bits) into
// meta[1]. Receiver unpacks both to avoid reading unused topk slots.
int max_num_logprobs = logprob_token_ids.shape()[1];
msg_sed.meta[1] = message_flag | (max_num_logprobs << 8);
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ 疑问 message_flag 是否保证只使用低 8 位(< 256)?

当前打包方式:msg_sed.meta[1] = message_flag | (max_num_logprobs << 8),低8位给 message_flag,高16位给 max_num_logprobs。若 message_flag 曾被赋值超过 255,两个字段的 bit 会互相污染,导致 Python 侧解包得到错误的 mtype 和 actual_topk。

建议:显式断言或掩码保护:

assert((message_flag & 0xFF) == message_flag);  // 确保 message_flag 只占用低 8 位
msg_sed.meta[1] = (message_flag & 0xFF) | (max_num_logprobs << 8);

msg_sed.meta[2] = bsz;
for (int i = 0; i < bsz; i++) {
int cur_token_num;
if (seq_lens_decoder_data[i] < prompt_lens_data[i] ||
Expand All @@ -139,29 +141,24 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids,
auto* cur_batch_msg_sed = &msg_sed.mtext[i];
int token_offset = cu_batch_token_offset_data[i];
for (int j = 0; j < cur_token_num; j++) {
// Use SPEC_LOGPROB_K+1 as stride (fixed struct layout), but only write
// max_num_logprobs columns to avoid filling unused topk slots.
auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)];
auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)];
if (j == 0) {
// first token has full logprobs
for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
for (int k = 0; k < max_num_logprobs; k++) {
if (k == 0) {
cur_tokens[k] =
(int)sampled_token_ids_data[i * max_draft_tokens + j];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
k];
} else if (k < max_num_logprobs) {
// only for first token
cur_tokens[k] =
(int)logprob_token_ids_data[(token_offset + j) *
(SPEC_LOGPROB_K + 1) +
k];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
k];
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
} else {
cur_tokens[k] = -1;
cur_scores[k] = 0.0;
cur_tokens[k] = (int)
logprob_token_ids_data[(token_offset + j) * max_num_logprobs +
k];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
}
}
cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,11 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,

int bsz = msg_rcv.meta[2];
output_tokens_data[0] = (int64_t)msg_rcv.meta[0];
// Unpack message_flag (low 8 bits) and actual_topk (high 16 bits) from
// meta[1]. Keep packed value; Python unpacks message_flag and actual_topk.
output_tokens_data[1] = (int64_t)msg_rcv.meta[1];
output_tokens_data[2] = (int64_t)msg_rcv.meta[2];
int actual_topk = (msg_rcv.meta[1] >> 8) & 0xFFFF;

int output_tokens_offset = 3 + SPEC_LOGPROB_MAX_BSZ;
for (int i = 0; i < bsz; i++) {
Expand All @@ -89,7 +92,7 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,
output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1));
auto* cur_batch_msg_rcv = &msg_rcv.mtext[i];
for (int j = 0; j < cur_token_num; j++) {
for (int k = 0; k < real_k + 1; k++) {
for (int k = 0; k < actual_topk; k++) {
cur_output_token[j * (SPEC_LOGPROB_K + 1) + k] =
(int64_t)cur_batch_msg_rcv->tokens[j * (SPEC_LOGPROB_K + 1) + k];
cur_output_score[j * (SPEC_LOGPROB_K + 1) + k] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,11 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
msg_sed.mtype = 1;
msg_sed.meta[0] = not_need_stop.data<bool>()[0] ? inference_msg_id_from_env
: -inference_msg_id_from_env;
msg_sed.meta[1] = message_flag;
msg_sed.meta[2] = bsz;
// Pack message_flag (low 8 bits) and max_num_logprobs (high 16 bits) into
// meta[1]. Receiver unpacks both to avoid reading unused topk slots.
int max_num_logprobs = logprob_token_ids.shape()[1];
msg_sed.meta[1] = message_flag | (max_num_logprobs << 8);
msg_sed.meta[2] = bsz;
for (int i = 0; i < bsz; i++) {
int cur_token_num;
if (seq_lens_decoder_data[i] < prompt_lens_data[i]) {
Expand All @@ -139,24 +141,20 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
auto* cur_batch_msg_sed = &msg_sed.mtext[i];
int token_offset = cu_batch_token_offset_data[i];
for (int j = 0; j < cur_token_num; j++) {
// Use SPEC_LOGPROB_K+1 as stride (fixed struct layout), but only write
// max_num_logprobs columns to avoid filling unused topk slots.
auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)];
auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)];
for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
for (int k = 0; k < max_num_logprobs; k++) {
if (k == 0) {
cur_tokens[k] = (int)sampled_token_ids_data[i * max_draft_tokens + j];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
k];
} else if (k < max_num_logprobs) {
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
} else {
cur_tokens[k] = (int)
logprob_token_ids_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
k];
logprob_token_ids_data[(token_offset + j) * max_num_logprobs + k];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
k];
} else {
cur_tokens[k] = -1;
cur_scores[k] = 0.0;
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
}
}
cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j];
Expand Down
36 changes: 23 additions & 13 deletions fastdeploy/output/token_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,12 +792,15 @@ def _process_batch_draft_tokens(self, mtype, batch, accept_num, tokens, scores,
metrics=None,
)

token_ids = tokens[i][:, 0].tolist()[: accept_num[i]]
tokens_i = tokens[i].tolist()
scores_i = scores[i].tolist()
ranks_i = ranks[i].tolist()
token_ids = [row[0] for row in tokens_i[: accept_num[i]]]
for batch_token_index in range(len(token_ids)):
result.outputs.logprob = float(scores[i, batch_token_index, 0])
topk_token_ids = tokens[i, batch_token_index, :].tolist()
topk_logprobs = scores[i, batch_token_index, :].tolist()
sampled_rank = ranks[i, batch_token_index].item()
result.outputs.logprob = scores_i[batch_token_index][0]
topk_token_ids = tokens_i[batch_token_index]
topk_logprobs = scores_i[batch_token_index]
sampled_rank = ranks_i[batch_token_index]

if result.outputs.draft_top_logprobs is None:
result.outputs.draft_top_logprobs = LogprobsLists(
Expand All @@ -824,16 +827,19 @@ def _process_batch_output(self):
mtype = 3
if self.cfg.speculative_config.method:
if self.use_logprobs:
mtype = int(self.output_tokens[1, 0].item())
# meta[1] packs message_flag (low 8 bits) and actual_topk (high 16 bits).
packed_meta1 = int(self.output_tokens[1, 0].item())
mtype = packed_meta1 & 0xFF
actual_topk = (packed_meta1 >> 8) & 0xFFFF
batch = self.output_tokens[2, 0]
accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]]
tokens = tokens[3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1)].reshape(
[batch, MAX_DRAFT_TOKENS, K + 1]
)
)[:, :, :actual_topk]
scores = (
self.output_scores[: batch * MAX_DRAFT_TOKENS * (K + 1)]
.numpy()
.reshape([batch, MAX_DRAFT_TOKENS, K + 1])
.reshape([batch, MAX_DRAFT_TOKENS, K + 1])[:, :, :actual_topk]
)
ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS])

Expand All @@ -842,6 +848,10 @@ def _process_batch_output(self):
batch_result = self._process_batch_draft_tokens(mtype, batch, accept_num, tokens, scores, ranks)
self.postprocess(batch_result, mtype)
return
# Pre-convert full arrays to Python lists once for MTP target token path.
tokens_lists = tokens.tolist()
scores_lists = scores.tolist()
ranks_list = ranks.tolist()
else:
batch = self.output_tokens[1]
accept_num = tokens[2 : batch + 2]
Expand Down Expand Up @@ -910,7 +920,7 @@ def _process_batch_output(self):
llm_logger.info(f"recovery stop signal found at task {task_id}")
token_ids = [RECOVERY_STOP_SIGNAL]
elif self.use_logprobs:
token_ids = tokens[i][:, 0].tolist()[: accept_num[i]]
token_ids = [row[0] for row in tokens_lists[i][: accept_num[i]]]
else:
token_ids = tokens[
2
Expand Down Expand Up @@ -1029,10 +1039,10 @@ def _process_batch_output(self):
task.output_token_ids.append(token_id)
if self.use_logprobs:
if self.cfg.speculative_config.method:
result.outputs.logprob = float(scores[i, batch_token_index, 0])
topk_token_ids = tokens[i, batch_token_index, :].tolist()
topk_logprobs = scores[i, batch_token_index, :].tolist()
sampled_rank = ranks[i, batch_token_index].item()
result.outputs.logprob = scores_lists[i][batch_token_index][0]
topk_token_ids = tokens_lists[i][batch_token_index]
topk_logprobs = scores_lists[i][batch_token_index]
sampled_rank = ranks_list[i][batch_token_index]
else:
# Use pre-converted lists (batch .tolist() done before the loop).
result.outputs.logprob = scores_lists[i][0]
Expand Down
14 changes: 5 additions & 9 deletions fastdeploy/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,15 +1226,11 @@ def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_p
req.sampling_params.top_p_normalized_logprobs and req.sampling_params.top_p != 1.0 for req in logprobs_reqs
)
if logprobs_reqs:
self.max_logprobs = (
max(
[
self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs
for req in logprobs_reqs
]
)
if not self.speculative_decoding
else 20
self.max_logprobs = max(
[
self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs
for req in logprobs_reqs
]
)

This comment was marked as outdated.

This comment was marked as outdated.

elif self.enable_logprob:
self.max_logprobs = None if not self.speculative_decoding else 0
Expand Down
Loading