diff --git a/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc index 02203a51cff..e49d9f32769 100644 --- a/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc +++ b/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc @@ -119,9 +119,11 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids, msg_sed.mtype = 1; msg_sed.meta[0] = not_need_stop.data()[0] ? inference_msg_id_from_env : -inference_msg_id_from_env; - msg_sed.meta[1] = message_flag; - msg_sed.meta[2] = bsz; + // Pack message_flag (low 8 bits) and max_num_logprobs (bits 8-23) into + // meta[1]. Receiver unpacks both to avoid reading unused topk slots. int max_num_logprobs = logprob_token_ids.shape()[1]; + msg_sed.meta[1] = message_flag | (max_num_logprobs << 8); + msg_sed.meta[2] = bsz; for (int i = 0; i < bsz; i++) { int cur_token_num; if (seq_lens_decoder_data[i] < prompt_lens_data[i] || @@ -139,29 +141,24 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids, auto* cur_batch_msg_sed = &msg_sed.mtext[i]; int token_offset = cu_batch_token_offset_data[i]; for (int j = 0; j < cur_token_num; j++) { + // Use SPEC_LOGPROB_K+1 as stride (fixed struct layout), but only write + // max_num_logprobs columns to avoid filling unused topk slots.
auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)]; auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)]; if (j == 0) { // first token has full logprobs - for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) { + for (int k = 0; k < max_num_logprobs; k++) { if (k == 0) { cur_tokens[k] = (int)sampled_token_ids_data[i * max_draft_tokens + j]; cur_scores[k] = - logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + - k]; - } else if (k < max_num_logprobs) { - // only for first token - cur_tokens[k] = - (int)logprob_token_ids_data[(token_offset + j) * - (SPEC_LOGPROB_K + 1) + - k]; - cur_scores[k] = - logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + - k]; + logprob_scores_data[(token_offset + j) * max_num_logprobs + k]; } else { - cur_tokens[k] = -1; - cur_scores[k] = 0.0; + cur_tokens[k] = (int) + logprob_token_ids_data[(token_offset + j) * max_num_logprobs + + k]; + cur_scores[k] = + logprob_scores_data[(token_offset + j) * max_num_logprobs + k]; } } cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j]; diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc index 4fd7d4103c4..5e1a7f61886 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc @@ -75,8 +75,11 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, int bsz = msg_rcv.meta[2]; output_tokens_data[0] = (int64_t)msg_rcv.meta[0]; + // Unpack message_flag (low 8 bits) and actual_topk (bits 8-23) from + // meta[1]. Keep packed value; Python unpacks message_flag and actual_topk.
output_tokens_data[1] = (int64_t)msg_rcv.meta[1]; output_tokens_data[2] = (int64_t)msg_rcv.meta[2]; + int actual_topk = (msg_rcv.meta[1] >> 8) & 0xFFFF; int output_tokens_offset = 3 + SPEC_LOGPROB_MAX_BSZ; for (int i = 0; i < bsz; i++) { @@ -89,7 +92,7 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1)); auto* cur_batch_msg_rcv = &msg_rcv.mtext[i]; for (int j = 0; j < cur_token_num; j++) { - for (int k = 0; k < real_k + 1; k++) { + for (int k = 0; k < actual_topk; k++) { cur_output_token[j * (SPEC_LOGPROB_K + 1) + k] = (int64_t)cur_batch_msg_rcv->tokens[j * (SPEC_LOGPROB_K + 1) + k]; cur_output_score[j * (SPEC_LOGPROB_K + 1) + k] = diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc index 0b3de384cee..08dc85f2e8f 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc @@ -121,9 +121,11 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, msg_sed.mtype = 1; msg_sed.meta[0] = not_need_stop.data()[0] ? inference_msg_id_from_env : -inference_msg_id_from_env; - msg_sed.meta[1] = message_flag; - msg_sed.meta[2] = bsz; + // Pack message_flag (low 8 bits) and max_num_logprobs (bits 8-23) into + // meta[1]. Receiver unpacks both to avoid reading unused topk slots.
int max_num_logprobs = logprob_token_ids.shape()[1]; + msg_sed.meta[1] = message_flag | (max_num_logprobs << 8); + msg_sed.meta[2] = bsz; for (int i = 0; i < bsz; i++) { int cur_token_num; if (seq_lens_decoder_data[i] < prompt_lens_data[i]) { @@ -139,24 +141,20 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, auto* cur_batch_msg_sed = &msg_sed.mtext[i]; int token_offset = cu_batch_token_offset_data[i]; for (int j = 0; j < cur_token_num; j++) { + // Use SPEC_LOGPROB_K+1 as stride (fixed struct layout), but only write + // max_num_logprobs columns to avoid filling unused topk slots. NOTE(review): assumes max_num_logprobs <= SPEC_LOGPROB_K+1 -- confirm. auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)]; auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)]; - for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) { + for (int k = 0; k < max_num_logprobs; k++) { if (k == 0) { cur_tokens[k] = (int)sampled_token_ids_data[i * max_draft_tokens + j]; cur_scores[k] = - logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + - k]; - } else if (k < max_num_logprobs) { + logprob_scores_data[(token_offset + j) * max_num_logprobs + k]; + } else { cur_tokens[k] = (int) - logprob_token_ids_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + - k]; + logprob_token_ids_data[(token_offset + j) * max_num_logprobs + k]; cur_scores[k] = - logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + - k]; - } else { - cur_tokens[k] = -1; - cur_scores[k] = 0.0; + logprob_scores_data[(token_offset + j) * max_num_logprobs + k]; } } cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j]; diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 65d3cb8d6ba..8df545e7d63 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -735,12 +735,15 @@ def _process_batch_draft_tokens(self, mtype, batch, accept_num, tokens, scores, metrics=None, ) - token_ids = tokens[i][:, 0].tolist()[: accept_num[i]] + tokens_i =
tokens[i].tolist() + scores_i = scores[i].tolist() + ranks_i = ranks[i].tolist() + token_ids = [row[0] for row in tokens_i[: accept_num[i]]] for batch_token_index in range(len(token_ids)): - result.outputs.logprob = float(scores[i, batch_token_index, 0]) - topk_token_ids = tokens[i, batch_token_index, :].tolist() - topk_logprobs = scores[i, batch_token_index, :].tolist() - sampled_rank = ranks[i, batch_token_index].item() + result.outputs.logprob = scores_i[batch_token_index][0] + topk_token_ids = tokens_i[batch_token_index] + topk_logprobs = scores_i[batch_token_index] + sampled_rank = ranks_i[batch_token_index] if result.outputs.draft_top_logprobs is None: result.outputs.draft_top_logprobs = LogprobsLists( @@ -767,16 +770,19 @@ def _process_batch_output(self): mtype = 3 if self.cfg.speculative_config.method: if self.use_logprobs: - mtype = int(self.output_tokens[1, 0].item()) + # meta[1] packs message_flag (low 8 bits) and actual_topk (bits 8-23). + packed_meta1 = int(self.output_tokens[1, 0].item()) + mtype = packed_meta1 & 0xFF + actual_topk = (packed_meta1 >> 8) & 0xFFFF batch = self.output_tokens[2, 0] accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]] tokens = tokens[3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1)].reshape( [batch, MAX_DRAFT_TOKENS, K + 1] - ) + )[:, :, :actual_topk] scores = ( self.output_scores[: batch * MAX_DRAFT_TOKENS * (K + 1)] .numpy() - .reshape([batch, MAX_DRAFT_TOKENS, K + 1]) + .reshape([batch, MAX_DRAFT_TOKENS, K + 1])[:, :, :actual_topk] ) ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS]) @@ -785,6 +791,10 @@ def _process_batch_output(self): batch_result = self._process_batch_draft_tokens(mtype, batch, accept_num, tokens, scores, ranks) self.postprocess(batch_result, mtype) return + # Pre-convert full arrays to Python lists once for MTP target token path.
+ tokens_lists = tokens.tolist() + scores_lists = scores.tolist() + ranks_list = ranks.tolist() else: batch = self.output_tokens[1] accept_num = tokens[2 : batch + 2] @@ -852,7 +862,7 @@ def _process_batch_output(self): ) token_ids = [RECOVERY_STOP_SIGNAL] elif self.use_logprobs: - token_ids = tokens[i][:, 0].tolist()[: accept_num[i]] + token_ids = [row[0] for row in tokens_lists[i][: accept_num[i]]] else: token_ids = tokens[ 2 @@ -984,10 +994,10 @@ def _process_batch_output(self): task.output_token_ids.append(token_id) if self.use_logprobs: if self.cfg.speculative_config.method: - result.outputs.logprob = float(scores[i, batch_token_index, 0]) - topk_token_ids = tokens[i, batch_token_index, :].tolist() - topk_logprobs = scores[i, batch_token_index, :].tolist() - sampled_rank = ranks[i, batch_token_index].item() + result.outputs.logprob = scores_lists[i][batch_token_index][0] + topk_token_ids = tokens_lists[i][batch_token_index] + topk_logprobs = scores_lists[i][batch_token_index] + sampled_rank = ranks_list[i][batch_token_index] else: # Use pre-converted lists (batch .tolist() done before the loop). 
result.outputs.logprob = scores_lists[i][0] diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 172e61808be..d2d2ab3f54e 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1275,15 +1275,11 @@ def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_p req.sampling_params.top_p_normalized_logprobs and req.sampling_params.top_p != 1.0 for req in logprobs_reqs ) if logprobs_reqs: - self.max_logprobs = ( - max( - [ - self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs - for req in logprobs_reqs - ] - ) - if not self.speculative_decoding - else 20 + self.max_logprobs = max( + [ + self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs + for req in logprobs_reqs + ] ) elif self.enable_logprob: self.max_logprobs = None if not self.speculative_decoding else 0