From 6d1c5a0fa95457c8f127cec5cd6676000737698b Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Tue, 31 Mar 2026 12:10:31 +0800 Subject: [PATCH 1/2] adapting gpu ngram operator --- custom_ops/gpu_ops/cpp_extensions.cc | 4 + .../draft_model/ngram_match_mixed.cu | 408 +++++++++++------- .../gpu_ops/speculate_decoding/ngram_match.cc | 227 ---------- .../gpu_ops/speculate_decoding/ngram_match.cu | 345 +++++++++++++++ .../speculate_decoding/ngram_match_core.cuh | 42 ++ fastdeploy/spec_decode/mtp.py | 31 +- fastdeploy/spec_decode/ngram.py | 39 +- tests/operators/test_hybrid_mtp_ngram.py | 40 +- tests/operators/test_ngram_match.py | 16 +- 9 files changed, 725 insertions(+), 427 deletions(-) delete mode 100644 custom_ops/gpu_ops/speculate_decoding/ngram_match.cc create mode 100644 custom_ops/gpu_ops/speculate_decoding/ngram_match.cu create mode 100644 custom_ops/gpu_ops/speculate_decoding/ngram_match_core.cuh diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 40898434bf1..e8b1f76ec63 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -939,7 +939,9 @@ void NgramMatch(const paddle::Tensor& input_ids, const paddle::Tensor& step_idx, const paddle::Tensor& draft_token_num, const paddle::Tensor& draft_tokens, + const paddle::Tensor& draft_tokens_copy, const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_this_time_copy, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& max_dec_len, @@ -952,7 +954,9 @@ void HybridMtpNgram(const paddle::Tensor& input_ids, const paddle::Tensor& step_idx, const paddle::Tensor& draft_token_num, const paddle::Tensor& draft_tokens, + const paddle::Tensor& draft_tokens_copy, const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_this_time_copy, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& max_dec_len, const int max_ngram_size, diff --git a/custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu b/custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu index 529bfd9ab0e..1bb18107e04 100644 --- a/custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu +++ b/custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu @@ -12,158 +12,242 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include -#include #include -#include #include +#include +#include +#include #include "paddle/extension.h" +#include "../ngram_match_core.cuh" #ifndef PD_BUILD_STATIC_OP #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) #endif -int sum_mixed(const int *value, int num) { - int sum_value = 0; - for (int i = 0; i <= num; i++) { - sum_value += value[i]; - } - return sum_value; -} +static __device__ int d_mixed_unprocessed_batch_size; -void find_candidate_pred_tokens_mixed(const int64_t *input_ids, - const int64_t *input_ids_len, - const int64_t *pre_ids, - const int64_t *step_idx, - const int *draft_token_num, - int64_t *draft_tokens, - int32_t *seq_lens_this_time, - int32_t *seq_lens_decoder, - int64_t *max_dec_len, - int64_t input_ids_stride, - int64_t pre_ids_stride, - int64_t draft_tokens_stride, - int64_t max_batch_size, - int max_ngram_size = 3, - int min_ngram_size = 1, - const int max_draft_tokens = 10) { - int threshold = 1024; - // dynamic in future - char *env_var = getenv("SPEC_TOKENUM_THRESHOLD"); - if (env_var) { - threshold = std::stoi(env_var); - } - int unprocessed_batch_size = 0; - for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) { - if (seq_lens_decoder[batch_idx] > 0) { - unprocessed_batch_size++; +// Phase 1: Block 0 counts unprocessed batches (seq_lens_decoder > 0). +// Blocks 1..N find ngram candidates for each batch in parallel. +template +__global__ void mixed_count_and_find_candidate_kernel( + const int64_t *input_ids, + const int64_t *input_ids_len, + const int64_t *pre_ids, + const int64_t *step_idx, + const int *draft_token_num, + int64_t *draft_tokens, + int64_t *draft_tokens_copy, + int32_t *seq_lens_this_time, + int32_t *seq_lens_this_time_copy, + int32_t *seq_lens_decoder, + int64_t *max_dec_len, + int64_t input_ids_stride, + int64_t pre_ids_stride, + int64_t draft_tokens_stride, + int max_ngram_size, + int min_ngram_size, + int max_draft_tokens, + int32_t *unprocessed_batch_size_global, + int64_t max_batch_size) { + int tid = threadIdx.x; + int bid = blockIdx.x; + + // Block 0: count unprocessed batches + if (bid == 0) { + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + int is_unprocessed = 0; + if (tid < max_batch_size) { + if (seq_lens_decoder[tid] > 0) { + is_unprocessed = 1; + } + } + int unprocessed = BlockReduce(temp_storage).Sum(is_unprocessed); + if (tid == 0) { + *unprocessed_batch_size_global = unprocessed; } + return; } - for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) { - const int ori_seq_len_this_time = seq_lens_this_time[batch_idx]; - int max_draft_tokens_query = std::min( - static_cast(max_draft_tokens - ori_seq_len_this_time + 1), - max_dec_len[batch_idx] - step_idx[batch_idx] - 1); - - if (ori_seq_len_this_time == 0 || max_draft_tokens_query <= 0) { - continue; + + int actual_bid = bid - 1; + if (actual_bid >= max_batch_size) return; + + __shared__ int32_t s_ori_seq_len; + __shared__ bool skip; + __shared__ int s_max_draft_tokens_query; + + if (tid == 0) { + s_ori_seq_len = seq_lens_this_time[actual_bid]; + int mdtq = max_draft_tokens - s_ori_seq_len + 1; + int64_t remaining = max_dec_len[actual_bid] - step_idx[actual_bid] - 1; + if (static_cast(mdtq) > remaining) + mdtq = static_cast(remaining); + s_max_draft_tokens_query = mdtq; + + // Initialize copy with original value + seq_lens_this_time_copy[actual_bid] = s_ori_seq_len; + + skip = (s_ori_seq_len == 0 || mdtq <= 0); + } + __syncthreads(); + + if (skip) return; + + const int64_t *cur_input_ids = input_ids + actual_bid * input_ids_stride; + int64_t *cur_draft_tokens_copy = + draft_tokens_copy + actual_bid * draft_tokens_stride; + const int64_t *cur_pre_ids = pre_ids + actual_bid * pre_ids_stride; + const int64_t cur_step_idx = step_idx[actual_bid]; + const int64_t cur_input_ids_len = input_ids_len[actual_bid]; + const int ori_seq_len = s_ori_seq_len; + const int max_draft_q = s_max_draft_tokens_query; + + __shared__ int64_t shared_match_idx; + + for (int ngram_size = max_ngram_size; ngram_size >= min_ngram_size; + --ngram_size) { + if (cur_step_idx < ngram_size) continue; + + const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size); + + // Search in input_ids + if (tid == 0) shared_match_idx = 0x7FFFFFFFFFFFFFFF; + __syncthreads(); + + sliding_window_search(cur_input_ids, + ngram, + cur_input_ids_len - ngram_size, + &shared_match_idx, + tid, + ngram_size); + __syncthreads(); + + if (shared_match_idx < 0x7FFFFFFFFFFFFFFF) { + if (tid == 0) { + int64_t start_idx = shared_match_idx + ngram_size; + int64_t end_idx = start_idx + max_draft_q; + if (end_idx > cur_input_ids_len) end_idx = cur_input_ids_len; + if (start_idx < end_idx) { + int64_t count = end_idx - start_idx; + seq_lens_this_time_copy[actual_bid] = + ori_seq_len + static_cast(count); + memcpy(cur_draft_tokens_copy + ori_seq_len, + cur_input_ids + start_idx, + sizeof(int64_t) * count); + } + } + break; } - const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride; - int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride; - const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride; - const int64_t cur_step_idx = step_idx[batch_idx]; - const int64_t cur_input_ids_len = input_ids_len[batch_idx]; - unprocessed_batch_size--; - - auto sum_token_num = sum_mixed(seq_lens_this_time, batch_idx); - int left_min_token_num = unprocessed_batch_size; - - if (sum_token_num + max_draft_tokens_query + left_min_token_num > - threshold) { - int tmp_max_draft_tokens = threshold - sum_token_num - left_min_token_num; - max_draft_tokens_query = - std::min(max_draft_tokens_query, tmp_max_draft_tokens); + // Search in generated tokens (pre_ids) + if (tid == 0) shared_match_idx = 0x7FFFFFFFFFFFFFFF; + __syncthreads(); + + sliding_window_search(cur_pre_ids, + ngram, + cur_step_idx - ngram_size, + &shared_match_idx, + tid, + ngram_size); + __syncthreads(); + + if (shared_match_idx < 0x7FFFFFFFFFFFFFFF) { + if (tid == 0) { + int64_t start_idx = shared_match_idx + ngram_size; + int64_t end_idx = start_idx + max_draft_q; + if (end_idx > cur_step_idx) end_idx = cur_step_idx; + if (start_idx < end_idx) { + int64_t count = end_idx - start_idx; + seq_lens_this_time_copy[actual_bid] = + ori_seq_len + static_cast(count); + memcpy(cur_draft_tokens_copy + ori_seq_len, + cur_pre_ids + start_idx, + sizeof(int64_t) * count); + } + } + break; } + } +} - if (sum_token_num + left_min_token_num >= threshold - 1) { - continue; +// Phase 2: Single block truncation with threshold. +template +__global__ void mixed_truncate_candidate( + const int64_t *step_idx, + const int *draft_token_num, + int64_t *max_dec_len, + int32_t *seq_lens_this_time, + int32_t *seq_lens_this_time_copy, + int64_t *draft_tokens, + int64_t *draft_tokens_copy, + int64_t draft_tokens_stride, + int64_t max_batch_size, + int max_draft_tokens, + int threshold, + int32_t *unprocessed_batch_size_global) { + int tid = threadIdx.x; + int is_processed = 0; + int allocating_token_num = 0; + int ori_seq_len = 0; + int max_draft_tokens_query = 0; + + if (tid < max_batch_size) { + ori_seq_len = seq_lens_this_time[tid]; + max_draft_tokens_query = max_draft_tokens - ori_seq_len + 1; + int64_t remaining = max_dec_len[tid] - step_idx[tid] - 1; + if (static_cast(max_draft_tokens_query) > remaining) + max_draft_tokens_query = static_cast(remaining); + + if (ori_seq_len > 0 && max_draft_tokens_query > 0) { + is_processed = 1; + allocating_token_num = seq_lens_this_time_copy[tid]; + } else { + allocating_token_num = ori_seq_len; } - bool match_global = false; - // apply ngram_match in input_ids - for (int ngram_size = max_ngram_size; - ngram_size >= min_ngram_size && !match_global; - --ngram_size) { - // Extract the last n tokens as our search ngram - if (cur_step_idx < ngram_size) { - continue; - } - const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size); - - // Iterate through sliding windows of size ngram_size - // bool match_input = false; - for (int64_t i = 0; i <= cur_input_ids_len - ngram_size && !match_global; - ++i) { - // Check if the current window matches the ngram - bool match_local = true; - for (int j = 0; j < ngram_size; j++) { - if (ngram[j] != cur_input_ids[i + j]) { - match_local = false; - break; - } - } - if (match_local) { - int64_t start_idx = i + ngram_size; - int64_t end_idx = - std::min(start_idx + max_draft_tokens_query, cur_input_ids_len); - if (start_idx >= end_idx) continue; + } - int64_t cur_draft_token_num = end_idx - start_idx; + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage_batch; + int processed_batch_size; + BlockScan(temp_storage_batch) + .InclusiveSum(is_processed, processed_batch_size); + __syncthreads(); - seq_lens_this_time[batch_idx] = - ori_seq_len_this_time + cur_draft_token_num; - memcpy(cur_draft_tokens + ori_seq_len_this_time, - cur_input_ids + start_idx, - sizeof(int64_t) * cur_draft_token_num); - // To break the current batch_idx for-loop - match_global = true; - break; + __shared__ typename BlockScan::TempStorage temp_storage_token; + int sum_token_num; + BlockScan(temp_storage_token) + .InclusiveSum(allocating_token_num, sum_token_num); + + if (is_processed && tid < max_batch_size) { + // Sum before this batch: prefix_sum - this_allocation + ori_seq_len + int sum_before = sum_token_num - allocating_token_num + ori_seq_len; + int unprocessed_tid = *unprocessed_batch_size_global - processed_batch_size; + + if (sum_before + unprocessed_tid < threshold - 1) { + int64_t *cur_draft_tokens = draft_tokens + tid * draft_tokens_stride; + int64_t *cur_draft_tokens_copy_ptr = + draft_tokens_copy + tid * draft_tokens_stride; + + int found_count = seq_lens_this_time_copy[tid] - ori_seq_len; + + if (sum_before + max_draft_tokens_query + unprocessed_tid > threshold) { + max_draft_tokens_query = threshold - sum_before - unprocessed_tid; + int actual_count = found_count < max_draft_tokens_query + ? found_count + : max_draft_tokens_query; + if (actual_count > 0) { + memcpy(cur_draft_tokens + ori_seq_len, + cur_draft_tokens_copy_ptr + ori_seq_len, + sizeof(int64_t) * actual_count); + seq_lens_this_time[tid] = ori_seq_len + actual_count; } - } - // apply ngram_match in generated tokens - if (!match_global) { - for (int64_t i = 0; i <= cur_step_idx - ngram_size && !match_global; - ++i) { - // Check if the current window matches the ngram - bool match_local = true; - - for (int j = 0; j < ngram_size; j++) { - if (ngram[j] != cur_pre_ids[i + j]) { - match_local = false; - break; - } - } - if (match_local) { - int64_t start_idx = i + ngram_size; - int64_t end_idx = - std::min(start_idx + max_draft_tokens_query, cur_step_idx); - - int64_t cur_draft_token_num = end_idx - start_idx; - - if (start_idx >= end_idx) continue; - // printf("match in Output with Ngram_size %d. - // %lld:[%lld,%lld]\n",ngram_size, cur_draft_token_num, start_idx, - // end_idx); - - seq_lens_this_time[batch_idx] = - ori_seq_len_this_time + cur_draft_token_num; - memcpy(cur_draft_tokens + ori_seq_len_this_time, - cur_pre_ids + start_idx, - sizeof(int64_t) * cur_draft_token_num); - match_global = true; - break; - } + } else { + if (found_count > 0) { + memcpy(cur_draft_tokens + ori_seq_len, + cur_draft_tokens_copy_ptr + ori_seq_len, + sizeof(int64_t) * found_count); + seq_lens_this_time[tid] = seq_lens_this_time_copy[tid]; } } } @@ -176,7 +260,9 @@ void HybridMtpNgram(const paddle::Tensor &input_ids, const paddle::Tensor &step_idx, const paddle::Tensor &draft_token_num, const paddle::Tensor &draft_tokens, + const paddle::Tensor &draft_tokens_copy, const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_this_time_copy, const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &max_dec_len, const int max_ngram_size, @@ -193,23 +279,53 @@ void HybridMtpNgram(const paddle::Tensor &input_ids, const int64_t max_batch_size = seq_lens_this_time.shape()[0]; - find_candidate_pred_tokens_mixed( - input_ids.data(), - input_ids_len.data(), - pre_ids.data(), + int tokennum_threshold = 1024; + char *env_var = getenv("SPEC_TOKENUM_THRESHOLD"); + if (env_var) { + tokennum_threshold = std::stoi(env_var); + } + + const int NTHREADS = 1024; + + int *d_unprocessed_ptr; + cudaGetSymbolAddress(reinterpret_cast(&d_unprocessed_ptr), + d_mixed_unprocessed_batch_size); + + mixed_count_and_find_candidate_kernel + <<>>( + input_ids.data(), + input_ids_len.data(), + pre_ids.data(), + step_idx.data(), + draft_token_num.data(), + const_cast(draft_tokens.data()), + const_cast(draft_tokens_copy.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_this_time_copy.data()), + const_cast(seq_lens_decoder.data()), + const_cast(max_dec_len.data()), + input_ids_stride, + pre_ids_stride, + draft_tokens_stride, + max_ngram_size, + min_ngram_size, + max_draft_tokens, + d_unprocessed_ptr, + max_batch_size); + + mixed_truncate_candidate<<<1, NTHREADS>>>( step_idx.data(), draft_token_num.data(), - const_cast(draft_tokens.data()), - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_decoder.data()), const_cast(max_dec_len.data()), - input_ids_stride, - pre_ids_stride, + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_this_time_copy.data()), + const_cast(draft_tokens.data()), + const_cast(draft_tokens_copy.data()), draft_tokens_stride, max_batch_size, - max_ngram_size, - min_ngram_size, - max_draft_tokens); + max_draft_tokens, + tokennum_threshold, + d_unprocessed_ptr); } PD_BUILD_STATIC_OP(hybrid_mtp_ngram) @@ -219,7 +335,9 @@ PD_BUILD_STATIC_OP(hybrid_mtp_ngram) "step_idx", "draft_token_num", "draft_tokens", + "draft_tokens_copy", "seq_lens_this_time", + "seq_lens_this_time_copy", "seq_lens_decoder", "max_dec_len"}) .Attrs({"max_ngram_size: int", diff --git a/custom_ops/gpu_ops/speculate_decoding/ngram_match.cc b/custom_ops/gpu_ops/speculate_decoding/ngram_match.cc deleted file mode 100644 index 56a2d3f81c3..00000000000 --- a/custom_ops/gpu_ops/speculate_decoding/ngram_match.cc +++ /dev/null @@ -1,227 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include -#include -#include "paddle/extension.h" - -#ifndef PD_BUILD_STATIC_OP -#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) -#endif - -int sum(const int *value, int num) { - int sum_value = 0; - for (int i = 0; i <= num; i++) { - sum_value += value[i]; - } - return sum_value; -} - -void find_candidate_pred_tokens(const int64_t *input_ids, - const int64_t *input_ids_len, - const int64_t *token_ids_all, - const int64_t *prompt_lens, - const int64_t *step_idx, - const int *draft_token_num, - int64_t *draft_tokens, - int32_t *seq_lens_this_time, - int32_t *seq_lens_encoder, - int32_t *seq_lens_decoder, - int64_t *max_dec_len, - int64_t input_ids_stride, - int64_t max_model_len, - int64_t draft_tokens_stride, - int64_t max_batch_size, - int max_ngram_size = 3, - int max_draft_tokens = 10) { - int threshold = 128; - char *env_var = getenv("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"); - if (env_var) { - threshold = std::stoi(env_var); - } - int unprocessed_batch_size = 0; - for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) { - if (seq_lens_encoder[batch_idx] > 0 || seq_lens_decoder[batch_idx] > 0) { - unprocessed_batch_size++; - } - } - for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) { - max_draft_tokens = - std::min(static_cast(draft_token_num[batch_idx]), - max_dec_len[batch_idx] - step_idx[batch_idx] - 1); - if (seq_lens_encoder[batch_idx] > 0) { - continue; - } else if (seq_lens_decoder[batch_idx] == 0) { - seq_lens_this_time[batch_idx] = 0; - continue; - } - // printf("bid: %d. enc: %d. dec. %d\n", batch_idx, - // seq_lens_encoder[batch_idx], seq_lens_decoder[batch_idx]); - - const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride; - int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride; - const int64_t *cur_pre_ids = - token_ids_all + batch_idx * max_model_len + prompt_lens[batch_idx]; - const int64_t cur_step_idx = step_idx[batch_idx]; - const int64_t cur_input_ids_len = input_ids_len[batch_idx]; - seq_lens_this_time[batch_idx] = 1; - unprocessed_batch_size--; - - auto sum_token_num = sum(seq_lens_this_time, batch_idx); - int left_min_token_num = unprocessed_batch_size; - - if (sum_token_num + max_draft_tokens + left_min_token_num > threshold) { - int tmp_max_draft_tokens = threshold - sum_token_num - left_min_token_num; - max_draft_tokens = tmp_max_draft_tokens < max_draft_tokens - ? tmp_max_draft_tokens - : max_draft_tokens; - } - - if (sum_token_num + left_min_token_num >= threshold - 1) { - continue; - } - - for (int ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) { - // Extract the last n tokens as our search ngram - if (cur_step_idx < ngram_size) { - continue; - } - const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size); - - // Iterate through sliding windows of size ngram_size - bool match_input = false; - for (int64_t i = 0; i <= cur_input_ids_len - ngram_size; ++i) { - // Check if the current window matches the ngram - bool match = true; - for (int j = 0; j < ngram_size; j++) { - if (ngram[j] != cur_input_ids[i + j]) { - match = false; - break; - } - } - if (match) { - int64_t start_idx = i + ngram_size; - int64_t end_idx = - std::min(start_idx + max_draft_tokens, cur_input_ids_len); - if (start_idx >= end_idx) continue; - - int64_t cur_draft_token_num = end_idx - start_idx; - - seq_lens_this_time[batch_idx] = cur_draft_token_num + 1; - memcpy(cur_draft_tokens + 1, - cur_input_ids + start_idx, - sizeof(int64_t) * cur_draft_token_num); - // To break the current batch_idx for-loop - ngram_size = 0; - match_input = true; - break; - // } - } - } - if (!match_input) { - for (int64_t i = 0; i <= cur_step_idx - ngram_size; ++i) { - // Check if the current window matches the ngram - bool match = true; - - for (int j = 0; j < ngram_size; j++) { - if (ngram[j] != cur_pre_ids[i + j]) { - match = false; - break; - } - } - - if (match) { - int64_t start_idx = i + ngram_size; - int64_t end_idx = - std::min(start_idx + max_draft_tokens, cur_step_idx); - int64_t cur_draft_token_num = end_idx - start_idx; - if (start_idx >= end_idx) continue; - - seq_lens_this_time[batch_idx] = cur_draft_token_num + 1; - memcpy(cur_draft_tokens + 1, - cur_pre_ids + start_idx, - sizeof(int64_t) * cur_draft_token_num); - ngram_size = 0; - break; - } - } - } - } - } -} - -void NgramMatch(const paddle::Tensor &input_ids, - const paddle::Tensor &input_ids_len, - const paddle::Tensor &token_ids_all, - const paddle::Tensor &prompt_lens, - const paddle::Tensor &step_idx, - const paddle::Tensor &draft_token_num, - const paddle::Tensor &draft_tokens, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &max_dec_len, - const int max_ngram_size, - const int max_draft_tokens) { - auto input_ids_shape = input_ids.shape(); - const int64_t input_ids_stride = input_ids_shape[1]; - - const int64_t max_model_len = token_ids_all.shape()[1]; - - auto draft_tokens_shape = draft_tokens.shape(); - const int64_t draft_tokens_stride = draft_tokens_shape[1]; - - const int64_t max_batch_size = seq_lens_this_time.shape()[0]; - - find_candidate_pred_tokens( - input_ids.data(), - input_ids_len.data(), - token_ids_all.data(), - prompt_lens.data(), - step_idx.data(), - draft_token_num.data(), - const_cast(draft_tokens.data()), - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_encoder.data()), - const_cast(seq_lens_decoder.data()), - const_cast(max_dec_len.data()), - input_ids_stride, - max_model_len, - draft_tokens_stride, - max_batch_size, - max_ngram_size, - max_draft_tokens); -} - -PD_BUILD_STATIC_OP(ngram_match) - .Inputs({"input_ids", - "input_ids_len", - "token_ids_all", - "prompt_lens", - "step_idx", - "draft_token_num", - "draft_tokens", - "seq_lens_this_time", - "seq_lens_encoder", - "seq_lens_decoder", - "max_dec_len"}) - .Attrs({"max_ngram_size: int", "max_draft_tokens: int"}) - .Outputs({"draft_tokens_out", "seq_lens_this_time_out"}) - .SetKernelFn(PD_KERNEL(NgramMatch)) - .SetInplaceMap({{"draft_tokens", "draft_tokens_out"}, - {"seq_lens_this_time", "seq_lens_this_time_out"}}); diff --git a/custom_ops/gpu_ops/speculate_decoding/ngram_match.cu b/custom_ops/gpu_ops/speculate_decoding/ngram_match.cu new file mode 100644 index 00000000000..cedd7f0b714 --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/ngram_match.cu @@ -0,0 +1,345 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include "paddle/extension.h" +#include "ngram_match_core.cuh" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +static __device__ int d_ngram_unprocessed_batch_size; + +// Phase 1: Block 0 counts unprocessed batches. +// Blocks 1..N find ngram candidates for each batch in parallel. +template +__global__ void ngram_count_and_find_candidate_kernel( + const int64_t *input_ids, + const int64_t *input_ids_len, + const int64_t *token_ids_all, + const int64_t *prompt_lens, + const int64_t *step_idx, + const int *draft_token_num, + int64_t *draft_tokens, + int64_t *draft_tokens_copy, + int32_t *seq_lens_this_time, + int32_t *seq_lens_this_time_copy, + int32_t *seq_lens_encoder, + int32_t *seq_lens_decoder, + int64_t *max_dec_len, + int64_t input_ids_stride, + int64_t max_model_len, + int64_t draft_tokens_stride, + int max_ngram_size, + int max_draft_tokens, + int32_t *unprocessed_batch_size_global, + int64_t max_batch_size) { + int tid = threadIdx.x; + int bid = blockIdx.x; + + // Block 0: count unprocessed batches + if (bid == 0) { + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + int is_unprocessed = 0; + if (tid < max_batch_size) { + if (seq_lens_encoder[tid] > 0 || seq_lens_decoder[tid] > 0) { + is_unprocessed = 1; + } + } + int unprocessed_batch_size = BlockReduce(temp_storage).Sum(is_unprocessed); + if (tid == 0) { + *unprocessed_batch_size_global = unprocessed_batch_size; + } + return; + } + + int actual_bid = bid - 1; + if (actual_bid >= max_batch_size) return; + + __shared__ bool skip; + if (tid == 0) { + skip = false; + if (seq_lens_encoder[actual_bid] > 0) { + seq_lens_this_time_copy[actual_bid] = seq_lens_this_time[actual_bid]; + skip = true; + } else if (seq_lens_decoder[actual_bid] == 0) { + skip = true; + seq_lens_this_time_copy[actual_bid] = 0; + seq_lens_this_time[actual_bid] = 0; + } + } + __syncthreads(); + + if (skip) return; + + if (tid == 0) { + int64_t draft_token_num_val = + static_cast(draft_token_num[actual_bid]); + int64_t remaining_len = max_dec_len[actual_bid] - step_idx[actual_bid] - 1; + max_draft_tokens = draft_token_num_val < remaining_len + ? static_cast(draft_token_num_val) + : static_cast(remaining_len); + seq_lens_this_time_copy[actual_bid] = 1; + seq_lens_this_time[actual_bid] = 1; + } + + const int64_t *cur_input_ids = input_ids + actual_bid * input_ids_stride; + int64_t *cur_draft_tokens_copy = + draft_tokens_copy + actual_bid * draft_tokens_stride; + const int64_t *cur_pre_ids = + token_ids_all + actual_bid * max_model_len + prompt_lens[actual_bid]; + const int64_t cur_step_idx = step_idx[actual_bid]; + const int64_t cur_input_ids_len = input_ids_len[actual_bid]; + + __shared__ int64_t shared_match_idx; + + for (int ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) { + if (cur_step_idx < ngram_size) continue; + + const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size); + + // Search in input_ids + if (tid == 0) shared_match_idx = 0x7FFFFFFFFFFFFFFF; + __syncthreads(); + + sliding_window_search(cur_input_ids, + ngram, + cur_input_ids_len - ngram_size, + &shared_match_idx, + tid, + ngram_size); + __syncthreads(); + + if (shared_match_idx < 0x7FFFFFFFFFFFFFFF) { + if (tid == 0) { + int64_t start_idx = shared_match_idx + ngram_size; + int64_t end_idx_cand = start_idx + max_draft_tokens; + int64_t end_idx = + end_idx_cand < cur_input_ids_len ? end_idx_cand : cur_input_ids_len; + if (start_idx < end_idx) { + int64_t cur_draft_token_num = end_idx - start_idx; + seq_lens_this_time_copy[actual_bid] = cur_draft_token_num + 1; + memcpy(cur_draft_tokens_copy + 1, + cur_input_ids + start_idx, + sizeof(int64_t) * cur_draft_token_num); + } + } + break; + } + + // Search in generated tokens (pre_ids) + if (tid == 0) shared_match_idx = 0x7FFFFFFFFFFFFFFF; + __syncthreads(); + + sliding_window_search(cur_pre_ids, + ngram, + cur_step_idx - ngram_size, + &shared_match_idx, + tid, + ngram_size); + __syncthreads(); + + if (shared_match_idx < 0x7FFFFFFFFFFFFFFF) { + if (tid == 0) { + int64_t start_idx = shared_match_idx + ngram_size; + int64_t end_idx_cand = start_idx + max_draft_tokens; + int64_t end_idx = + end_idx_cand < cur_step_idx ? end_idx_cand : cur_step_idx; + if (start_idx < end_idx) { + int64_t cur_draft_token_num = end_idx - start_idx; + seq_lens_this_time_copy[actual_bid] = cur_draft_token_num + 1; + memcpy(cur_draft_tokens_copy + 1, + cur_pre_ids + start_idx, + sizeof(int64_t) * cur_draft_token_num); + } + } + break; + } + } +} + +// Phase 2: Single block truncation with threshold. +template +__global__ void ngram_truncate_candidate( + const int64_t *step_idx, + const int *draft_token_num, + int64_t *max_dec_len, + int32_t *seq_lens_this_time, + int32_t *seq_lens_this_time_copy, + int64_t *draft_tokens, + int64_t *draft_tokens_copy, + int64_t draft_tokens_stride, + int64_t max_batch_size, + int max_draft_tokens, + int threshold, + int32_t *unprocessed_batch_size_global) { + int tid = threadIdx.x; + int is_processed = 0; + int allocating_token_num = 0; + + if (tid < max_batch_size) { + int64_t draft_token_num_val = static_cast(draft_token_num[tid]); + int64_t remaining_len = max_dec_len[tid] - step_idx[tid] - 1; + max_draft_tokens = draft_token_num_val < remaining_len + ? static_cast(draft_token_num_val) + : static_cast(remaining_len); + + if (seq_lens_this_time[tid] == 1) is_processed = 1; + if (seq_lens_this_time[tid] > 0) { + allocating_token_num = seq_lens_this_time_copy[tid]; // decoding phase + if (seq_lens_this_time[tid] > 1) + allocating_token_num = seq_lens_this_time[tid]; // prefilling phase + } + } + + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage_batch; + int processed_batch_size; + BlockScan(temp_storage_batch) + .InclusiveSum(is_processed, processed_batch_size); + __syncthreads(); + + __shared__ typename BlockScan::TempStorage temp_storage_token; + int sum_token_num; + BlockScan(temp_storage_token) + .InclusiveSum(allocating_token_num, sum_token_num); + + if (is_processed && tid < max_batch_size) { + sum_token_num = sum_token_num - allocating_token_num + 1; + int unprocessed_batch_size_tid = + *unprocessed_batch_size_global - processed_batch_size; + + if (sum_token_num + unprocessed_batch_size_tid < threshold - 1) { + int64_t *cur_draft_tokens = draft_tokens + tid * draft_tokens_stride; + int64_t *cur_draft_tokens_copy = + draft_tokens_copy + tid * draft_tokens_stride; + if (sum_token_num + max_draft_tokens + unprocessed_batch_size_tid > + threshold) { + max_draft_tokens = + threshold - sum_token_num - unprocessed_batch_size_tid; + memcpy(cur_draft_tokens + 1, + cur_draft_tokens_copy + 1, + sizeof(int64_t) * max_draft_tokens); + seq_lens_this_time[tid] = max_draft_tokens + 1; + } else { + memcpy(cur_draft_tokens + 1, + cur_draft_tokens_copy + 1, + sizeof(int64_t) * (seq_lens_this_time_copy[tid] - 1)); + seq_lens_this_time[tid] = seq_lens_this_time_copy[tid]; + } + } + } +} + +void NgramMatch(const paddle::Tensor &input_ids, + const paddle::Tensor &input_ids_len, + const paddle::Tensor &token_ids_all, + const paddle::Tensor &prompt_lens, + const paddle::Tensor &step_idx, + const paddle::Tensor &draft_token_num, + const paddle::Tensor &draft_tokens, + const paddle::Tensor &draft_tokens_copy, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_this_time_copy, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &max_dec_len, + const int max_ngram_size, + const int max_draft_tokens) { + auto input_ids_shape = input_ids.shape(); + const int64_t input_ids_stride = input_ids_shape[1]; + + const int64_t max_model_len = token_ids_all.shape()[1]; + + auto draft_tokens_shape = draft_tokens.shape(); + const int64_t draft_tokens_stride = draft_tokens_shape[1]; + + const int64_t max_batch_size = seq_lens_this_time.shape()[0]; + + int tokennum_threshold = 128; + char *env_var = getenv("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"); + if (env_var) { + tokennum_threshold = std::stoi(env_var); + } + + const int NTHREADS = 1024; + + int *d_unprocessed_ptr; + cudaGetSymbolAddress(reinterpret_cast(&d_unprocessed_ptr), + d_ngram_unprocessed_batch_size); + + ngram_count_and_find_candidate_kernel + <<>>( + input_ids.data(), + input_ids_len.data(), + token_ids_all.data(), + prompt_lens.data(), + step_idx.data(), + draft_token_num.data(), + const_cast(draft_tokens.data()), + const_cast(draft_tokens_copy.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_this_time_copy.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(max_dec_len.data()), + input_ids_stride, + max_model_len, + draft_tokens_stride, + max_ngram_size, + max_draft_tokens, + d_unprocessed_ptr, + max_batch_size); + + ngram_truncate_candidate<<<1, NTHREADS>>>( + step_idx.data(), + draft_token_num.data(), + const_cast(max_dec_len.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_this_time_copy.data()), + const_cast(draft_tokens.data()), + const_cast(draft_tokens_copy.data()), + draft_tokens_stride, + max_batch_size, + max_draft_tokens, + tokennum_threshold, + d_unprocessed_ptr); +} + +PD_BUILD_STATIC_OP(ngram_match) + .Inputs({"input_ids", + "input_ids_len", + "token_ids_all", + "prompt_lens", + "step_idx", + "draft_token_num", + "draft_tokens", + "draft_tokens_copy", + "seq_lens_this_time", + "seq_lens_this_time_copy", + "seq_lens_encoder", + "seq_lens_decoder", + "max_dec_len"}) + .Attrs({"max_ngram_size: int", "max_draft_tokens: int"}) + .Outputs({"draft_tokens_out", "seq_lens_this_time_out"}) + .SetKernelFn(PD_KERNEL(NgramMatch)) + .SetInplaceMap({{"draft_tokens", "draft_tokens_out"}, + {"seq_lens_this_time", "seq_lens_this_time_out"}}); diff --git a/custom_ops/gpu_ops/speculate_decoding/ngram_match_core.cuh b/custom_ops/gpu_ops/speculate_decoding/ngram_match_core.cuh new file mode 100644 index 00000000000..dbdce4fd051 --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/ngram_match_core.cuh @@ -0,0 +1,42 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +// Shared device function: parallel sliding window ngram search. +// Searches for the first occurrence of ngram[0..ngram_size-1] within +// cur_input_ids[0..search_len+ngram_size-1] using multiple threads. +// The minimum matching index is atomically written to shared_start_idx. +__device__ __forceinline__ void sliding_window_search( + const int64_t* cur_input_ids, + const int64_t* ngram, + const int64_t search_len, + int64_t* shared_start_idx, + int tid, + int ngram_size) { + for (int64_t i = tid; i <= search_len; i += blockDim.x) { + bool match = true; + for (int j = 0; j < ngram_size; ++j) { + if (ngram[j] != cur_input_ids[i + j]) { + match = false; + break; + } + } + if (match) { + atomicMin(reinterpret_cast(shared_start_idx), + static_cast(i)); + break; + } + } +} diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 88c1bbc5614..807aaf88eeb 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -1217,28 +1217,27 @@ def _update_status(self): ) def _extend_draft_token_with_ngram_match(self): - # TODO(liuzichang): Optimize this Kernel to CUDA Kernel to reduce lantency - device = paddle.CUDAPinnedPlace() + # Lazy initialization of GPU copy buffers + if not hasattr(self, "_ngram_draft_tokens_copy") or self._ngram_draft_tokens_copy is None: + self._ngram_draft_tokens_copy = paddle.zeros_like(self.target_model_inputs["draft_tokens"]) + self._ngram_seq_lens_this_time_copy = paddle.zeros_like(self.target_model_inputs["seq_lens_this_time"]) - draft_tokens = self.target_model_inputs["draft_tokens"].cpu() - seq_lens_this_time = self.target_model_inputs["seq_lens_this_time"].cpu() - seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu() hybrid_mtp_ngram( - self.model_inputs["input_ids_cpu"], - self.model_inputs["input_ids_len"], - self.model_inputs["pre_ids"]._copy_to(device, True), - self.model_inputs["step_idx"].cpu(), - self.target_model_inputs["actual_draft_token_num"].cpu(), - draft_tokens, - seq_lens_this_time, - seq_lens_decoder, - self.model_inputs["max_dec_len"].cpu(), + self.model_inputs["input_ids"], + self.model_inputs["input_ids_len"].cuda(), + self.model_inputs["pre_ids"], + self.model_inputs["step_idx"], + self.target_model_inputs["actual_draft_token_num"], + self.target_model_inputs["draft_tokens"], + self._ngram_draft_tokens_copy, + self.target_model_inputs["seq_lens_this_time"], + self._ngram_seq_lens_this_time_copy, + self.model_inputs["seq_lens_decoder"], + self.model_inputs["max_dec_len"], self.max_ngram_size, self.min_ngram_size, self.max_draft_token_num, ) - self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda() - self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda() def _run_impl( self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False, is_dummy_run: bool = False diff --git a/fastdeploy/spec_decode/ngram.py b/fastdeploy/spec_decode/ngram.py index b64e8fb5790..9f78a141c84 100644 --- a/fastdeploy/spec_decode/ngram.py +++ b/fastdeploy/spec_decode/ngram.py @@ -36,7 +36,9 @@ class NgramProposer(Proposer): def __init__(self, fd_config: "FDConfig"): super().__init__(fd_config) self.max_ngram_size = self.speculative_config.max_ngram_size - self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu() + self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64") + self._draft_tokens_copy = None + self._seq_lens_this_time_copy = None def update(self, bid: int, seq_len: int): """ @@ -48,26 +50,25 @@ def _run_impl(self, share_inputs): """ run """ - draft_tokens = share_inputs["draft_tokens"].cpu() - seq_lens_this_time = share_inputs["seq_lens_this_time"].cpu() - seq_lens_encoder = share_inputs["seq_lens_encoder"].cpu() - seq_lens_decoder = share_inputs["seq_lens_decoder"].cpu() + # Lazy initialization of GPU copy buffers + if self._draft_tokens_copy is None: + self._draft_tokens_copy = paddle.zeros_like(share_inputs["draft_tokens"]) + self._seq_lens_this_time_copy = paddle.zeros_like(share_inputs["seq_lens_this_time"]) ngram_match( - share_inputs["input_ids_cpu"], - self.input_ids_len.cpu(), - share_inputs["token_ids_all"].cpu(), - share_inputs["prompt_lens"].cpu(), - share_inputs["step_idx"].cpu(), - share_inputs["actual_draft_token_num"].cpu(), - draft_tokens, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - share_inputs["max_dec_len"].cpu(), + share_inputs["input_ids"], + self.input_ids_len, + share_inputs["token_ids_all"], + share_inputs["prompt_lens"], + share_inputs["step_idx"], + share_inputs["actual_draft_token_num"], + share_inputs["draft_tokens"], + self._draft_tokens_copy, + share_inputs["seq_lens_this_time"], + self._seq_lens_this_time_copy, + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs["max_dec_len"], self.max_ngram_size, self.max_draft_token_num, ) - share_inputs["draft_tokens"][:] = draft_tokens.cuda() - share_inputs["seq_lens_encoder"][:] = seq_lens_encoder.cuda() - share_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda() diff --git a/tests/operators/test_hybrid_mtp_ngram.py b/tests/operators/test_hybrid_mtp_ngram.py index 6c111f93763..30a72f8671b 100644 --- a/tests/operators/test_hybrid_mtp_ngram.py +++ b/tests/operators/test_hybrid_mtp_ngram.py @@ -24,6 +24,7 @@ class TestNgramMatchMixed(unittest.TestCase): def setUp(self): + paddle.set_device("gpu") self.max_bsz = 2 self.max_draft_tokens = 5 self.max_len = 32 @@ -31,34 +32,39 @@ def setUp(self): self.max_ngram_size = 5 self.min_ngram_size = 2 - # 初始化输入 tensor - self.input_ids = paddle.full(shape=[self.max_bsz, self.max_len], fill_value=-1, dtype="int64").cpu() - self.input_ids_len = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int64").cpu() - self.pre_ids = paddle.full(shape=[self.max_bsz, self.max_len], fill_value=-1, dtype="int64").cpu() - self.step_idx = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int64").cpu() - self.draft_token_num = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int32").cpu() + # 初始化输入 tensor (on GPU) + self.input_ids = paddle.full(shape=[self.max_bsz, self.max_len], fill_value=-1, dtype="int64") + self.input_ids_len = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int64") + self.pre_ids = paddle.full(shape=[self.max_bsz, self.max_len], fill_value=-1, dtype="int64") + self.step_idx = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int64") + self.draft_token_num = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int32") self.draft_tokens = paddle.full( shape=[self.max_bsz, self.max_draft_tokens + 1], fill_value=-1, dtype="int64", - ).cpu() - self.seq_lens_this_time = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int32").cpu() - self.seq_lens_decoder = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int32").cpu() + ) + self.draft_tokens_copy = paddle.zeros( + shape=[self.max_bsz, self.max_draft_tokens + 1], + dtype="int64", + ) + self.seq_lens_this_time = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int32") + self.seq_lens_this_time_copy = paddle.zeros(shape=[self.max_bsz, 1], dtype="int32") + self.seq_lens_decoder = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int32") self.max_dec_len = paddle.full( shape=[self.max_bsz, 1], fill_value=self.max_dec_len, dtype="int64", - ).cpu() + ) # 设置具体数据 - self.input_ids[:, :10] = np.arange(0, 10) + self.input_ids[:, :10] = paddle.to_tensor(np.arange(0, 10), dtype="int64") self.input_ids_len[:] = 10 - pre_ids_np = np.array([10, 9, 8, 7, 6, 10, 9, 8, 7], dtype="int32") - self.pre_ids[:, : pre_ids_np.shape[0]] = pre_ids_np + pre_ids_np = np.array([10, 9, 8, 7, 6, 10, 9, 8, 7], dtype="int64") + self.pre_ids[:, : pre_ids_np.shape[0]] = paddle.to_tensor(pre_ids_np, dtype="int64") self.step_idx[:] = 8 self.draft_token_num[:] = 5 - self.draft_tokens[:, :2] = np.array([8, 7]) + self.draft_tokens[:, :2] = paddle.to_tensor(np.array([8, 7]), dtype="int64") self.seq_lens_this_time[:] = 2 self.seq_lens_decoder[:] = 12 self.max_dec_len[:] = 512 @@ -75,7 +81,9 @@ def test_ngram_match_mixed(self): self.step_idx, self.draft_token_num, self.draft_tokens, + self.draft_tokens_copy, self.seq_lens_this_time, + self.seq_lens_this_time_copy, self.seq_lens_decoder, self.max_dec_len, self.max_ngram_size, @@ -83,8 +91,8 @@ def test_ngram_match_mixed(self): self.max_draft_tokens, ) - np.testing.assert_allclose(self.seq_lens_this_time.numpy(), self.ref_seq_lens_this_time) - np.testing.assert_allclose(self.draft_tokens.numpy(), self.ref_draft_tokens) + np.testing.assert_allclose(self.seq_lens_this_time.cpu().numpy(), self.ref_seq_lens_this_time) + np.testing.assert_allclose(self.draft_tokens.cpu().numpy(), self.ref_draft_tokens) if __name__ == "__main__": diff --git a/tests/operators/test_ngram_match.py b/tests/operators/test_ngram_match.py index 139b487de53..923c152f292 100644 --- a/tests/operators/test_ngram_match.py +++ b/tests/operators/test_ngram_match.py @@ -22,7 +22,7 @@ class TestNgramMatchOp(unittest.TestCase): def setUp(self): - paddle.set_device("cpu") + paddle.set_device("gpu") def test_basic_match(self): """ @@ -44,9 +44,11 @@ def test_basic_match(self): draft_token_num = paddle.to_tensor([3], dtype="int32") # Placeholder for draft tokens draft_tokens = paddle.zeros([batch_size, seq_len], dtype="int64") + draft_tokens_copy = paddle.zeros([batch_size, seq_len], dtype="int64") # Sequence lengths for this time step seq_lens_this_time = paddle.zeros([batch_size], dtype="int32") + seq_lens_this_time_copy = paddle.zeros([batch_size], dtype="int32") # Sequence lengths for encoder seq_lens_encoder = paddle.zeros([batch_size], dtype="int32") # Sequence lengths for decoder @@ -62,7 +64,9 @@ def test_basic_match(self): step_idx, draft_token_num, draft_tokens, + draft_tokens_copy, seq_lens_this_time, + seq_lens_this_time_copy, seq_lens_encoder, seq_lens_decoder, max_dec_len, @@ -71,12 +75,12 @@ def test_basic_match(self): ) # Extract non-zero tokens and assert the results. - nonzero_tokens = draft_tokens.numpy()[0][draft_tokens.numpy()[0] != 0] + nonzero_tokens = draft_tokens.cpu().numpy()[0][draft_tokens.cpu().numpy()[0] != 0] expected_tokens = [50, 60] self.assertTrue((nonzero_tokens == expected_tokens).all()) # Check length - self.assertEqual(seq_lens_this_time.numpy()[0], 3) + self.assertEqual(seq_lens_this_time.cpu().numpy()[0], 3) def test_no_match(self): """ @@ -90,8 +94,10 @@ def test_no_match(self): step_idx = paddle.to_tensor([3], dtype="int64") draft_token_num = paddle.to_tensor([2], dtype="int32") draft_tokens = paddle.zeros([batch_size, 4], dtype="int64") + draft_tokens_copy = paddle.zeros([batch_size, 4], dtype="int64") seq_lens_this_time = paddle.zeros([batch_size], dtype="int32") + seq_lens_this_time_copy = paddle.zeros([batch_size], dtype="int32") seq_lens_encoder = paddle.zeros([batch_size], dtype="int32") seq_lens_decoder = paddle.ones([batch_size], dtype="int32") max_dec_len = paddle.to_tensor([6], dtype="int64") @@ -104,7 +110,9 @@ def test_no_match(self): step_idx, draft_token_num, draft_tokens, + draft_tokens_copy, seq_lens_this_time, + seq_lens_this_time_copy, seq_lens_encoder, seq_lens_decoder, max_dec_len, @@ -113,7 +121,7 @@ def test_no_match(self): ) # No match → should only keep 1 token - self.assertEqual(seq_lens_this_time.numpy()[0], 1) + self.assertEqual(seq_lens_this_time.cpu().numpy()[0], 1) if __name__ == "__main__": From 31fd29ddae50eeaff497085dc6fe4ac937da9dff Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Tue, 31 Mar 2026 15:03:16 +0800 Subject: [PATCH 2/2] update --- .../gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu | 1 + custom_ops/gpu_ops/speculate_decoding/ngram_match.cu | 1 + 2 files changed, 2 insertions(+) diff --git a/custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu b/custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu index 1bb18107e04..a11c801be3b 100644 --- a/custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu +++ b/custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu @@ -17,6 +17,7 @@ #include #include #include +#include "helper.h" #include "paddle/extension.h" #include "../ngram_match_core.cuh" diff --git a/custom_ops/gpu_ops/speculate_decoding/ngram_match.cu b/custom_ops/gpu_ops/speculate_decoding/ngram_match.cu index cedd7f0b714..908651d88e2 100644 --- a/custom_ops/gpu_ops/speculate_decoding/ngram_match.cu +++ b/custom_ops/gpu_ops/speculate_decoding/ngram_match.cu @@ -17,6 +17,7 @@ #include #include #include +#include "helper.h" #include "paddle/extension.h" #include "ngram_match_core.cuh"