diff --git a/custom_ops/xpu_ops/src/ops/mtp/build_sampling_params.cc b/custom_ops/xpu_ops/src/ops/mtp/build_sampling_params.cc new file mode 100644 index 00000000000..409b1c8a24c --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/mtp/build_sampling_params.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/extension.h" +#include "xpu/plugin.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +namespace api = baidu::xpu::api; + +std::vector BuildSamplingParams( + const paddle::Tensor& top_p, + const paddle::Tensor& top_k, + paddle::Tensor& infer_seed, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const int64_t token_num_output_cpu, + const int64_t increment_value) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + api::Context* ctx = xpu_ctx->x_context(); + std::unique_ptr cpu_ctx; + if (top_p.is_cpu()) { + cpu_ctx = std::make_unique(api::kCPU); + ctx = cpu_ctx.get(); + } + + int real_bsz = static_cast(seq_lens_this_time.shape()[0]); + + auto top_p_padding = paddle::empty( + {token_num_output_cpu, 1}, paddle::DataType::FLOAT32, top_p.place()); + auto top_k_padding = paddle::empty( + {token_num_output_cpu, 1}, paddle::DataType::INT64, top_p.place()); + auto topp_seed = paddle::empty( + {token_num_output_cpu, 1}, paddle::DataType::INT64, top_p.place()); + + int r = + fastdeploy::plugin::build_sampling_params(ctx, + top_p_padding.data(), + top_k_padding.data(), + topp_seed.data(), + top_p.data(), + top_k.data(), + infer_seed.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + real_bsz, + token_num_output_cpu, + increment_value); + PD_CHECK(r == 0, "fastdeploy::plugin::build_sampling_params failed."); + + return {top_p_padding, top_k_padding, topp_seed}; +} + +std::vector> BuildSamplingParamsInferShape( + const std::vector& top_p_shape, + const std::vector& top_k_shape, + const std::vector& infer_seed_shape, + const std::vector& seq_lens_this_time_shape, + const std::vector& seq_lens_encoder_shape) { + // token_num is dynamic; return a placeholder shape of [-1, 1] + return {{-1, 1}, {-1, 1}, {-1, 1}}; +} + +std::vector BuildSamplingParamsInferDtype( + const paddle::DataType& top_p_dtype, + const paddle::DataType& top_k_dtype, + const paddle::DataType& infer_seed_dtype, + const paddle::DataType& seq_lens_this_time_dtype, + const paddle::DataType& seq_lens_encoder_dtype) { + return {paddle::DataType::FLOAT32, + paddle::DataType::INT64, + paddle::DataType::INT64}; +} + +PD_BUILD_STATIC_OP(build_sampling_params) + .Inputs({"top_p", + "top_k", + "infer_seed", + "seq_lens_this_time", + "seq_lens_encoder"}) + .Outputs({"top_p_padding", "top_k_padding", "topp_seed"}) + .Attrs({"token_num_output_cpu: int64_t", "increment_value: int64_t"}) + .SetKernelFn(PD_KERNEL(BuildSamplingParams)) + .SetInferShapeFn(PD_INFER_SHAPE(BuildSamplingParamsInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(BuildSamplingParamsInferDtype)); diff --git a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h index acc3d0fe68f..2535076a81d 100644 --- a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h +++ b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h @@ -843,6 +843,19 @@ DLL_EXPORT int reasoning_phase_token_constraint( int max_seq_len, int allowed_tokens_len); +DLL_EXPORT int build_sampling_params(api::Context* ctx, + float* top_p_padding, + int64_t* top_k_padding, + int64_t* topp_seed, + const float* top_p, + const int64_t* top_k, + int64_t* infer_seed, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + int bs, + int64_t token_num, + int64_t increment_value); + /*--------------------------------------- MTP end * --------------------------------------------*/ diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/build_sampling_params.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/build_sampling_params.xpu new file mode 100644 index 00000000000..5aa1a39d93e --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/build_sampling_params.xpu @@ -0,0 +1,127 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace fd_xpu3 { + +constexpr int64_t BUILD_SAMPLING_MAX_INFER_SEED = 2147483646LL; + +// Each cluster handles one batch item (bi = cluster_id). +// Within the cluster, core 0 reads the per-batch scalars and broadcasts via +// shared memory; all cores then fill their assigned token slots in parallel. +__global__ void build_sampling_params_kernel( + __global_ptr__ float* top_p_padding, + __global_ptr__ int64_t* top_k_padding, + __global_ptr__ int64_t* topp_seed, + __global_ptr__ const float* top_p, + __global_ptr__ const int64_t* top_k, + __global_ptr__ int64_t* infer_seed, + __global_ptr__ const int* seq_lens_this_time, + __global_ptr__ const int* seq_lens_encoder, + int bs, + int64_t token_num, + int64_t increment_value) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + int nclusters = cluster_num(); + + // Shared scalars broadcast from core 0 to all cores in the cluster. + __shared__ float sm_top_p; + __shared__ int64_t sm_top_k; + __shared__ int64_t sm_seed; + __shared__ int sm_repeat; // number of tokens this batch produces + __shared__ int sm_pad_start; // starting index in the output buffer + + // Shared prefix-sum buffer: each cluster computes its own pad_start via + // a two-pass scan over seq_lens_this_time / seq_lens_encoder. + // We use a simple approach: core 0 of cluster 0 writes per-batch start + // offsets into a global scratch area is not available here, so instead we + // compute pad_start with a sequential scan in core 0 of each cluster. + // Because clusters run concurrently we cannot share a global accumulator; + // instead each cluster independently sums the first `bi` entries. + // This is O(bs) per cluster but bs is typically small (<=512). + + for (int bi = clusterid; bi < bs; bi += nclusters) { + if (cid == 0) { + // Read per-batch parameters from global memory. + float lm_top_p; + int64_t lm_top_k; + int64_t lm_seed; + int lm_slt; // seq_lens_this_time[bi] + int lm_sle; // seq_lens_encoder[bi] + + GM2LM_ASYNC(top_p + bi, &lm_top_p, sizeof(float)); + GM2LM_ASYNC(top_k + bi, &lm_top_k, sizeof(int64_t)); + GM2LM_ASYNC(infer_seed + bi, &lm_seed, sizeof(int64_t)); + GM2LM_ASYNC(seq_lens_this_time + bi, &lm_slt, sizeof(int)); + GM2LM(seq_lens_encoder + bi, &lm_sle, sizeof(int)); // sync barrier + + bool is_decoder = (lm_sle == 0); + int repeat = is_decoder ? lm_slt : 1; + + // Compute pad_start = sum of token counts for batches [0, bi). + int pad_start = 0; + for (int k = 0; k < bi; k++) { + int slt_k, sle_k; + GM2LM_ASYNC(seq_lens_this_time + k, &slt_k, sizeof(int)); + GM2LM(seq_lens_encoder + k, &sle_k, sizeof(int)); + pad_start += (sle_k == 0) ? slt_k : 1; + } + + sm_top_p = lm_top_p; + sm_top_k = lm_top_k; + sm_seed = lm_seed; + sm_repeat = repeat; + sm_pad_start = pad_start; + } + mfence(); + sync_all(); + + // All cores fill token slots [sm_pad_start, sm_pad_start + sm_repeat). + float bi_top_p = sm_top_p; + int64_t bi_top_k = sm_top_k; + int64_t bi_seed = sm_seed; + int repeat = sm_repeat; + int pad_start = sm_pad_start; + + for (int local_pos = cid; local_pos < repeat; local_pos += ncores) { + int pad_idx = pad_start + local_pos; + float lm_top_p_out = bi_top_p; + int64_t lm_top_k_out = bi_top_k; + // Decoder tokens: offset seed by position; encoder token: no offset. + int64_t offset = static_cast(local_pos) * 4; + int64_t lm_seed_out = (bi_seed + offset) % BUILD_SAMPLING_MAX_INFER_SEED; + + LM2GM_ASYNC(&lm_top_p_out, top_p_padding + pad_idx, sizeof(float)); + LM2GM_ASYNC(&lm_top_k_out, top_k_padding + pad_idx, sizeof(int64_t)); + LM2GM(&lm_seed_out, topp_seed + pad_idx, sizeof(int64_t)); + } + + // Core 0 updates infer_seed in-place. + if (cid == 0) { + int64_t new_seed = + (bi_seed + increment_value) % BUILD_SAMPLING_MAX_INFER_SEED; + LM2GM(&new_seed, infer_seed + bi, sizeof(int64_t)); + } + + mfence(); + sync_all(); + } +} + +} // namespace fd_xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/build_sampling_params.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/build_sampling_params.cpp new file mode 100644 index 00000000000..c2da64c44fc --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/build_sampling_params.cpp @@ -0,0 +1,162 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace fd_xpu3 { +__attribute__((global)) void build_sampling_params_kernel( + float* top_p_padding, + int64_t* top_k_padding, + int64_t* topp_seed, + const float* top_p, + const int64_t* top_k, + int64_t* infer_seed, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + int bs, + int64_t token_num, + int64_t increment_value); +} // namespace fd_xpu3 + +namespace fastdeploy { +namespace plugin { + +constexpr int64_t BUILD_SAMPLING_MAX_INFER_SEED = 2147483646LL; + +static int cpu_wrapper(api::Context* ctx, + float* top_p_padding, + int64_t* top_k_padding, + int64_t* topp_seed, + const float* top_p, + const int64_t* top_k, + int64_t* infer_seed, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + int bs, + int64_t token_num, + int64_t increment_value) { + int64_t pad_idx = 0; + for (int bi = 0; bi < bs; bi++) { + bool is_decoder = (seq_lens_encoder[bi] == 0); + int repeat = is_decoder ? seq_lens_this_time[bi] : 1; + int64_t bi_seed = infer_seed[bi]; + for (int local_pos = 0; local_pos < repeat; local_pos++) { + int64_t offset = is_decoder ? static_cast(local_pos) * 4 : 0LL; + top_p_padding[pad_idx] = top_p[bi]; + top_k_padding[pad_idx] = top_k[bi]; + topp_seed[pad_idx] = (bi_seed + offset) % BUILD_SAMPLING_MAX_INFER_SEED; + pad_idx++; + } + infer_seed[bi] = + (infer_seed[bi] + increment_value) % BUILD_SAMPLING_MAX_INFER_SEED; + } + return api::SUCCESS; +} + +static int xpu3_wrapper(api::Context* ctx, + float* top_p_padding, + int64_t* top_k_padding, + int64_t* topp_seed, + const float* top_p, + const int64_t* top_k, + int64_t* infer_seed, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + int bs, + int64_t token_num, + int64_t increment_value) { + using XPU_INT64 = typename api::XPUIndexType::type; + int32_t ret_xre = fd_xpu3:: + build_sampling_params_kernel<<ncluster(), 64, ctx->xpu_stream>>>( + top_p_padding, + reinterpret_cast(top_k_padding), + reinterpret_cast(topp_seed), + top_p, + reinterpret_cast(top_k), + reinterpret_cast(infer_seed), + seq_lens_this_time, + seq_lens_encoder, + bs, + token_num, + increment_value); + KERNEL_ASSERT_SUCCESS(ctx, ret_xre); + return api::SUCCESS; +} + +int build_sampling_params(api::Context* ctx, + float* top_p_padding, + int64_t* top_k_padding, + int64_t* topp_seed, + const float* top_p, + const int64_t* top_k, + int64_t* infer_seed, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + int bs, + int64_t token_num, + int64_t increment_value) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "build_sampling_params", float); + WRAPPER_DUMP_PARAM5( + ctx, top_p_padding, top_k_padding, topp_seed, top_p, top_k); + WRAPPER_DUMP_PARAM5( + ctx, infer_seed, seq_lens_this_time, seq_lens_encoder, bs, token_num); + WRAPPER_DUMP_PARAM1(ctx, increment_value); + WRAPPER_DUMP(ctx); + + WRAPPER_CHECK_PTR(ctx, float, token_num, top_p_padding); + WRAPPER_CHECK_PTR(ctx, int64_t, token_num, top_k_padding); + WRAPPER_CHECK_PTR(ctx, int64_t, token_num, topp_seed); + WRAPPER_CHECK_PTR(ctx, float, bs, top_p); + WRAPPER_CHECK_PTR(ctx, int64_t, bs, top_k); + WRAPPER_CHECK_PTR(ctx, int64_t, bs, infer_seed); + WRAPPER_CHECK_PTR(ctx, int, bs, seq_lens_this_time); + WRAPPER_CHECK_PTR(ctx, int, bs, seq_lens_encoder); + + WRAPPER_ASSERT_GT(ctx, bs, 0); + WRAPPER_ASSERT_GT(ctx, token_num, 0); + + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + top_p_padding, + top_k_padding, + topp_seed, + top_p, + top_k, + infer_seed, + seq_lens_this_time, + seq_lens_encoder, + bs, + token_num, + increment_value); + } else if (ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + top_p_padding, + top_k_padding, + topp_seed, + top_p, + top_k, + infer_seed, + seq_lens_this_time, + seq_lens_encoder, + bs, + token_num, + increment_value); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace fastdeploy diff --git a/custom_ops/xpu_ops/test/test_build_sampling_params.py b/custom_ops/xpu_ops/test/test_build_sampling_params.py new file mode 100644 index 00000000000..ab53bd48b3c --- /dev/null +++ b/custom_ops/xpu_ops/test/test_build_sampling_params.py @@ -0,0 +1,275 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for build_sampling_params XPU op. + +Verifies that the XPU kernel produces the same output as the Python reference +implementation (padding_sampling_params) for all cases: + - pure decoder batches (seq_lens_encoder == 0) + - pure encoder batches (seq_lens_encoder > 0) + - mixed encoder/decoder batches + - single-item batch (bs=1) + - seed wrap-around near MAX_INFER_SEED +""" + +import unittest + +import numpy as np +import paddle + +DEVICE_PLACE = paddle.XPUPlace(0) if paddle.is_compiled_with_xpu() else paddle.CPUPlace() +MAX_INFER_SEED = 2147483646 + + +# --------------------------------------------------------------------------- +# Python reference implementation (mirrors sampler.py padding_sampling_params) +# --------------------------------------------------------------------------- + + +def ref_build_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder, increment_value): + """ + Pure-Python reference that mirrors the cpu_wrapper logic in + build_sampling_params.cpp. + + Returns (top_p_padding, top_k_padding, topp_seed) as numpy arrays of + shape [token_num, 1], and infer_seed updated in-place. + """ + bs = len(seq_lens_this_time) + infer_seed = infer_seed.copy() # don't mutate the input + + top_p_out, top_k_out, seed_out = [], [], [] + for bi in range(bs): + is_decoder = seq_lens_encoder[bi] == 0 + repeat = int(seq_lens_this_time[bi]) if is_decoder else 1 + bi_seed = int(infer_seed[bi]) + for local_pos in range(repeat): + offset = local_pos * 4 if is_decoder else 0 + top_p_out.append([top_p[bi]]) + top_k_out.append([top_k[bi]]) + seed_out.append([(bi_seed + offset) % MAX_INFER_SEED]) + infer_seed[bi] = (bi_seed + increment_value) % MAX_INFER_SEED + + top_p_out = np.array(top_p_out, dtype=np.float32) + top_k_out = np.array(top_k_out, dtype=np.int64) + seed_out = np.array(seed_out, dtype=np.int64) + return top_p_out, top_k_out, seed_out, infer_seed + + +# --------------------------------------------------------------------------- +# Helper: run the XPU op and return numpy results +# --------------------------------------------------------------------------- + + +def run_op(top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder, increment_value): + from fastdeploy.model_executor.ops.xpu import build_sampling_params + + token_num = int( + sum(seq_lens_this_time[i] if seq_lens_encoder[i] == 0 else 1 for i in range(len(seq_lens_this_time))) + ) + + tp = paddle.to_tensor(top_p, place=DEVICE_PLACE) + tk = paddle.to_tensor(top_k, place=DEVICE_PLACE) + seed = paddle.to_tensor(infer_seed.copy(), place=DEVICE_PLACE) + slt = paddle.to_tensor(seq_lens_this_time, place=DEVICE_PLACE) + sle = paddle.to_tensor(seq_lens_encoder, place=DEVICE_PLACE) + + tp_pad, tk_pad, seed_pad = build_sampling_params( + tp, + tk, + seed, + slt, + sle, + token_num_output_cpu=token_num, + increment_value=increment_value, + ) + return (tp_pad.numpy(), tk_pad.numpy(), seed_pad.numpy(), seed.numpy()) # seed was updated in-place inside the op + + +# --------------------------------------------------------------------------- +# Assertion helper +# --------------------------------------------------------------------------- + + +def assert_close(ref, got, name, rtol=1e-5, atol=1e-5): + assert ref.shape == got.shape, f"[{name}] shape mismatch: ref={ref.shape} got={got.shape}" + if ref.dtype in (np.float32, np.float64): + ok = np.allclose(ref, got, rtol=rtol, atol=atol) + else: + ok = np.array_equal(ref, got) + assert ok, f"[{name}] value mismatch.\n" f"ref=\n{ref}\ngot=\n{got}\ndiff=\n{ref - got}" + + +# --------------------------------------------------------------------------- +# Test cases +# --------------------------------------------------------------------------- + + +class TestBuildSamplingParams(unittest.TestCase): + + def _run_and_compare(self, top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder, increment_value): + try: + tp_xpu, tk_xpu, sd_xpu, seed_xpu = run_op( + top_p, + top_k, + infer_seed, + seq_lens_this_time, + seq_lens_encoder, + increment_value, + ) + except ImportError as e: + self.skipTest(f"XPU op not available: {e}") + + tp_ref, tk_ref, sd_ref, seed_ref = ref_build_sampling_params( + top_p, + top_k, + infer_seed, + seq_lens_this_time, + seq_lens_encoder, + increment_value, + ) + + assert_close(tp_ref, tp_xpu, "top_p_padding") + assert_close(tk_ref, tk_xpu, "top_k_padding") + assert_close(sd_ref, sd_xpu, "topp_seed") + assert_close(seed_ref, seed_xpu, "infer_seed") + + # ------------------------------------------------------------------ + # Test 1: pure decoder batch (all seq_lens_encoder == 0) + # ------------------------------------------------------------------ + def test_pure_decoder(self): + top_p = np.array([0.9, 0.8, 0.7], dtype=np.float32) + top_k = np.array([50, 40, 30], dtype=np.int64) + infer_seed = np.array([100, 200, 300], dtype=np.int64) + seq_lens_this_time = np.array([4, 3, 2], dtype=np.int32) + seq_lens_encoder = np.array([0, 0, 0], dtype=np.int32) + increment_value = 16 # token_num * 4 = (4+3+2)*4 is typical + + self._run_and_compare( + top_p, + top_k, + infer_seed, + seq_lens_this_time, + seq_lens_encoder, + increment_value, + ) + + # ------------------------------------------------------------------ + # Test 2: pure encoder batch (all seq_lens_encoder > 0) + # -> each batch contributes exactly 1 output token, no seed offset + # ------------------------------------------------------------------ + def test_pure_encoder(self): + top_p = np.array([0.95, 0.85], dtype=np.float32) + top_k = np.array([10, 20], dtype=np.int64) + infer_seed = np.array([1000, 2000], dtype=np.int64) + seq_lens_this_time = np.array([5, 7], dtype=np.int32) + seq_lens_encoder = np.array([5, 7], dtype=np.int32) # all encoder + increment_value = 8 + + self._run_and_compare( + top_p, + top_k, + infer_seed, + seq_lens_this_time, + seq_lens_encoder, + increment_value, + ) + + # ------------------------------------------------------------------ + # Test 3: mixed encoder/decoder + # ------------------------------------------------------------------ + def test_mixed(self): + top_p = np.array([0.9, 0.8, 0.7, 0.6], dtype=np.float32) + top_k = np.array([50, 40, 30, 20], dtype=np.int64) + infer_seed = np.array([10, 20, 30, 40], dtype=np.int64) + seq_lens_this_time = np.array([3, 4, 2, 5], dtype=np.int32) + # batch 0,2 are decoder; batch 1,3 are encoder + seq_lens_encoder = np.array([0, 4, 0, 5], dtype=np.int32) + increment_value = 20 + + self._run_and_compare( + top_p, + top_k, + infer_seed, + seq_lens_this_time, + seq_lens_encoder, + increment_value, + ) + + # ------------------------------------------------------------------ + # Test 4: bs=1 (single item) + # ------------------------------------------------------------------ + def test_single_item(self): + top_p = np.array([0.5], dtype=np.float32) + top_k = np.array([5], dtype=np.int64) + infer_seed = np.array([42], dtype=np.int64) + seq_lens_this_time = np.array([6], dtype=np.int32) + seq_lens_encoder = np.array([0], dtype=np.int32) + increment_value = 24 + + self._run_and_compare( + top_p, + top_k, + infer_seed, + seq_lens_this_time, + seq_lens_encoder, + increment_value, + ) + + # ------------------------------------------------------------------ + # Test 5: seed near wrap-around boundary + # ------------------------------------------------------------------ + def test_seed_wraparound(self): + # Seeds close to MAX_INFER_SEED to trigger modulo wrap + near_max = MAX_INFER_SEED - 8 + top_p = np.array([0.9, 0.9], dtype=np.float32) + top_k = np.array([50, 50], dtype=np.int64) + infer_seed = np.array([near_max, near_max - 1], dtype=np.int64) + seq_lens_this_time = np.array([4, 4], dtype=np.int32) + seq_lens_encoder = np.array([0, 0], dtype=np.int32) + increment_value = 16 + + self._run_and_compare( + top_p, + top_k, + infer_seed, + seq_lens_this_time, + seq_lens_encoder, + increment_value, + ) + + # ------------------------------------------------------------------ + # Test 6: seq_lens_this_time == 1 for all decoder batches + # (degenerate case: each decoder produces exactly one token) + # ------------------------------------------------------------------ + def test_single_token_per_batch(self): + top_p = np.array([0.9, 0.8, 0.7], dtype=np.float32) + top_k = np.array([50, 40, 30], dtype=np.int64) + infer_seed = np.array([1, 2, 3], dtype=np.int64) + seq_lens_this_time = np.array([1, 1, 1], dtype=np.int32) + seq_lens_encoder = np.array([0, 0, 0], dtype=np.int32) + increment_value = 4 + + self._run_and_compare( + top_p, + top_k, + infer_seed, + seq_lens_this_time, + seq_lens_encoder, + increment_value, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 3452da1bc85..abf106dca9a 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -60,6 +60,12 @@ build_sampling_params_logprob, naive_update_model_status, ) +else: + from fastdeploy.model_executor.ops.xpu import ( + build_sampling_params, + top_p_candidates, + verify_draft_tokens, + ) def _apply_triton_top_k_top_p( @@ -1181,19 +1187,12 @@ def _normal_sample_xpu( share_inputs: List[paddle.Tensor], ) -> SamplerOutput: """Normal sampling for NAIVE mode on XPU.""" - top_p, top_k, topp_seed = padding_sampling_params( - sampling_metadata.top_p, - sampling_metadata.top_k, - sampling_metadata.seed, - paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]), - paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]), - ) _, next_tokens = top_k_top_p_sampling( probs, - top_p=top_p, - top_k=top_k, + top_p=sampling_metadata.top_p, + top_k=sampling_metadata.top_k, top_k_list=sampling_metadata.top_k_list, - topp_seed=topp_seed, + topp_seed=sampling_metadata.topp_seed, ) real_bsz = share_inputs["seq_lens_this_time"].shape[0] running_mask = (paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]) > 0).cast("int32") @@ -1213,25 +1212,24 @@ def _verify_and_sample_xpu( sampling_metadata: SamplingMetadata, max_model_len: int, share_inputs: List[paddle.Tensor], + increment_value: int, accept_all_drafts: bool = False, reject_all_drafts: bool = False, ) -> SamplerOutput: """Verify draft tokens (MTP/Ngram mode) on XPU using verify_draft_tokens.""" - from fastdeploy.model_executor.ops.xpu import ( - top_p_candidates, - verify_draft_tokens, - ) target_tokens = None candidate_ids, candidate_scores, candidate_lens = None, None, None if self.verify_strategy == VerifyStrategy.TARGET_MATCH: - top_p, top_k, topp_seed = padding_sampling_params( + top_p, top_k, topp_seed = build_sampling_params( sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.seed, - paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]), - paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]), + share_inputs["seq_lens_this_time"], + share_inputs["seq_lens_encoder"], + token_num_output_cpu=int(share_inputs["cu_seqlens_q_output"][-1]), + increment_value=increment_value, ) _, target_tokens = top_k_top_p_sampling( probs, @@ -1293,6 +1291,7 @@ def forward_xpu( sampling_metadata: SamplingMetadata, max_model_len: int, share_inputs: List[paddle.Tensor], + increment_value: int, accept_all_drafts: bool = False, reject_all_drafts: bool = False, ) -> SamplerOutput: @@ -1346,6 +1345,7 @@ def forward_xpu( sampling_metadata, max_model_len, share_inputs, + increment_value, accept_all_drafts, reject_all_drafts, ) diff --git a/fastdeploy/model_executor/xpu_pre_and_post_process.py b/fastdeploy/model_executor/xpu_pre_and_post_process.py index bdce35eca0c..ee1b4412a2d 100644 --- a/fastdeploy/model_executor/xpu_pre_and_post_process.py +++ b/fastdeploy/model_executor/xpu_pre_and_post_process.py @@ -137,8 +137,12 @@ def xpu_pre_process( ) = speculate_pre_process( token_num_cpu, input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder, seq_lens_decoder ) - share_inputs["cu_seqlens_q_output"] = cu_seqlens_q_output - share_inputs["batch_id_per_token_output"] = batch_id_per_token_output + if use_cudagraph: + share_inputs["cu_seqlens_q_output"].copy_(cu_seqlens_q_output, False) + share_inputs["batch_id_per_token_output"].copy_(batch_id_per_token_output, False) + else: + share_inputs["cu_seqlens_q_output"] = cu_seqlens_q_output + share_inputs["batch_id_per_token_output"] = batch_id_per_token_output else: ( ids_remove_padding, diff --git a/fastdeploy/spec_decode/mtp_xpu.py b/fastdeploy/spec_decode/mtp_xpu.py index 1f762deb73c..4721b0e192e 100644 --- a/fastdeploy/spec_decode/mtp_xpu.py +++ b/fastdeploy/spec_decode/mtp_xpu.py @@ -115,6 +115,14 @@ def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_ru for attn_backend in self.attn_backends: attn_backend.init_attention_metadata(self.forward_meta) + # 1. CUDA Graph capture sizes must be recorded in descending order (large → small). + # 2. In multi-step execution, only the first step should be captured. + # self.forward_meta.step_use_cudagraph = ( + # step_use_cudagraph and self.draft_model_use_cudagraph and not (substep > 0 and is_dummy_run) + # ) + # TODO(chenhuan09): support cudagraph for draft model + self.forward_meta.step_use_cudagraph = False + def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, real_bsz: int = 0): """ Main process for MTP inference. @@ -122,7 +130,22 @@ def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, step_use_cudagraph: bool Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP. """ + # TODO(chenhuan09):remove not_need_stop + # is_blocking = ( + # (not self.fd_config.scheduler_config.enable_overlap_schedule) + # or is_dummy_run + # or self.exist_prefill() + # or real_bsz == 0 # always True + # ) for substep in range(self.num_model_steps): + # if is_blocking: + # token_num_cpu = self.model_inputs["seq_lens_this_time"].numpy().sum().item() + # else: + # if substep == 0: + # token_num_cpu = self.model_inputs["target_hidden_states"].shape[0] + # else: + # token_num_cpu = real_bsz + # if token_num_cpu > 0: if self.model_inputs["not_need_stop"]: self.model_inputs["substep"] = substep # Remove padding @@ -155,7 +178,12 @@ def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, ) self.model_inputs["attn_mask_offsets"].copy_(attn_mask_offsets, False) - self._initialize_forward_meta() + self._initialize_forward_meta( + step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, substep=substep + ) + # Padding inputs for cuda graph + self.padding_cudagraph_inputs() + # Get sampling metadata self.sampling_metadata = SamplingMetadata( temperature=self.model_inputs["temperature"], @@ -178,12 +206,15 @@ def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, if self.num_model_steps > 1: self.model_inputs.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"]) - + real_num = self.model_inputs["ids_remove_padding"].shape[0] + target_hidden_states = self.model_inputs["target_hidden_states"][:real_num] model_output = self.model( ids_remove_padding=self.model_inputs["ids_remove_padding"], - previous_hidden_states=self.model_inputs["target_hidden_states"], + previous_hidden_states=target_hidden_states, forward_meta=self.forward_meta, ) + if self.forward_meta.step_use_cudagraph: + model_output = model_output[: self.real_token_num] hidden_states = xpu_process_output(model_output, self.forward_meta, self.model_inputs) # 4. Compute logits, Sample logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta) @@ -323,3 +354,16 @@ def _update_status(self): self.cache_config.block_size, self.max_draft_token_num, ) + + def padding_cudagraph_inputs(self) -> None: + """ + Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch. + In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch. + """ + # In init_attention_metadata, the decode buffer has already been cleared + + # To adapt to CUDA Graph, keep the forward pass at the maximum batch size. + if self.forward_meta.step_use_cudagraph: + self.forward_meta.seq_lens_this_time = self.model_inputs["seq_lens_this_time"] + self.real_token_num = self.forward_meta.ids_remove_padding.shape[0] + return diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index 9f3eace566f..c1d3e0378de 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -173,9 +173,12 @@ def __init__( self.share_inputs.init_share_inputs() self.max_num_seqs = self.fd_config.scheduler_config.max_num_seqs + self.increment_value = ( + 4 if not self.speculative_decoding else (self.speculative_config.num_speculative_tokens + 1) * 4 + ) self.infer_seed_increment = paddle.full( shape=[self.scheduler_config.max_num_seqs, 1], - fill_value=4, + fill_value=self.increment_value, dtype="int64", ).cpu() @@ -849,22 +852,8 @@ def _prepare_inputs(self, is_dummy_run=False) -> None: if self.use_cudagraph: # Update Batch type for cuda graph for only_decode_batch if_only_decode = self.only_decode() - - only_decode_use_cudagraph = self.use_cudagraph and if_only_decode - # Update config about moe for better performance - # TODO(wanglongzhi):Modifying the config at runtime is not appropriate; it needs to be moved to forward_meta. It will be used in MoEMethodBase.apply() - if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": - self.fd_config.model_config.moe_phase.phase = "decode" if if_only_decode else "prefill" - if self.speculative_decoding: - self.proposer.fd_config.parallel_config.moe_phase.phase = "decode" if if_only_decode else "prefill" - - # Update Batch type for cuda graph for only_prefill_batch - only_prefill_use_cudagraph = self.use_cudagraph and self.cudagraph_only_prefill and self.only_prefill() - self.forward_meta.step_use_cudagraph = ( - only_prefill_use_cudagraph - if self.cudagraph_only_prefill - else only_decode_use_cudagraph and self.forward_meta.ids_remove_padding.shape[0] > 0 + self.use_cudagraph and if_only_decode and self.forward_meta.ids_remove_padding.shape[0] > 0 ) # Update bad tokens len @@ -876,9 +865,7 @@ def _prepare_inputs(self, is_dummy_run=False) -> None: if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query": self.forward_meta.kv_signal_sender = self.share_inputs["kv_signal_sender"] - if ( - self.fd_config.scheduler_config.splitwise_role == "mixed" and envs.FD_XPU_ENABLE_MIXED_EP_MODE - ): # Centralized scenario: the phase is initialized as "prefill" by default. During inference runtime, different types of batches can achieve phase switching at this point. + if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": if_only_decode = self.only_decode() self.fd_config.model_config.moe_phase.phase = "decode" if if_only_decode else "prefill" @@ -1137,6 +1124,7 @@ def _dummy_run( batch_size: paddle.Tensor, expected_decode_len: int = 1, in_capturing: bool = False, + accept_all_drafts=False, ) -> paddle.Tensor: """ Use dummy inputs to run before formal execution. @@ -1165,7 +1153,7 @@ def _dummy_run( ) while True: - self.execute_model(is_dummy_run=True, in_capturing=in_capturing) + self.execute_model(is_dummy_run=True, in_capturing=in_capturing, accept_all_drafts=accept_all_drafts) if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0: break @@ -1214,14 +1202,30 @@ def capture_model(self) -> None: capture_sizes = self.cudagraph_capture_sizes.copy() try: - for batch_size in sorted(capture_sizes, reverse=True): - self._dummy_run( - num_tokens=self.scheduler_config.max_num_batched_tokens, - batch_size=batch_size, - expected_decode_len=expected_decode_len, - in_capturing=True, - ) - logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}") + if self.speculative_decoding and self.spec_method in [SpecMethod.MTP, SpecMethod.SUFFIX]: + for capture_size in sorted(capture_sizes, reverse=True): + expected_decode_len = (self.speculative_config.num_speculative_tokens + 1) * 2 + self._dummy_run( + num_tokens=self.fd_config.get_max_chunk_tokens(), + batch_size=int(capture_size / (self.speculative_config.num_speculative_tokens + 1)), + in_capturing=True, + expected_decode_len=expected_decode_len, + accept_all_drafts=True, + ) + logger.info( + f"Warm up the model with the num_tokens:{capture_size}, expected_decode_len:{expected_decode_len}" + ) + else: + for batch_size in sorted(capture_sizes, reverse=True): + self._dummy_run( + num_tokens=self.scheduler_config.max_num_batched_tokens, + batch_size=batch_size, + expected_decode_len=expected_decode_len, + in_capturing=True, + ) + logger.info( + f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}" + ) except RuntimeError as e: if "out of memory" in str(e): raise RuntimeError( @@ -1278,6 +1282,7 @@ def execute_model( num_running_requests: int = None, is_dummy_run: bool = False, in_capturing: bool = False, + accept_all_drafts: bool = False, ) -> Optional[ModelRunnerOutput]: """ The Entrance of model execute. @@ -1291,14 +1296,18 @@ class at the server level, which is too granular for ModelRunner. # 0. set debug level # self._set_debug_level(0x1, model_forward_batch, is_dummy_run) with kv_signal_sender_context_manager(self.pd_disaggregation_mode) as sender: - self.share_inputs["kv_signal_sender"] = sender # 1. Prepare inputs of model and decoder. self._prepare_inputs(is_dummy_run=is_dummy_run) + # 2. Padding inputs for cuda graph + self.padding_cudagraph_inputs() if is_dummy_run: self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph - # 2. Padding inputs for cuda grph - self.padding_cudagraph_inputs() + else: + self.forward_meta.step_use_cudagraph = ( + self.forward_meta.step_use_cudagraph + and self.real_token_num <= self.fd_config.graph_opt_config.max_capture_size + ) num_tokens = self.share_inputs["ids_remove_padding"].shape[0] if not self.parallel_config.enable_expert_parallel and num_tokens <= 0: @@ -1311,13 +1320,11 @@ class at the server level, which is too granular for ModelRunner. self._execute_empty_input(self.forward_meta) return None - # 2. Padding inputs for cuda grph - model_inputs = {} model_inputs["ids_remove_padding"] = self.share_inputs["ids_remove_padding"] if self.enable_mm: model_inputs["image_features"] = self.share_inputs["image_features"] - # 3. Execute model + # 3. Execute model_output = self.model( model_inputs, forward_meta=self.forward_meta, @@ -1348,6 +1355,8 @@ class at the server level, which is too granular for ModelRunner. self.sampling_metadata, self.model_config.max_model_len, self.share_inputs, + self.increment_value, + accept_all_drafts=accept_all_drafts, ) if self.parallel_config.tensor_parallel_size > 1: paddle.distributed.broadcast( @@ -1444,13 +1453,18 @@ class at the server level, which is too granular for ModelRunner. # 6. Draft model propose if self.speculative_decoding and self.proposer is not None: if self.spec_method == SpecMethod.MTP: - self.proposer.run(full_hidden_states=model_output) + self.proposer.run( + full_hidden_states=model_output, + step_use_cudagraph=self.forward_meta.step_use_cudagraph, + is_dummy_run=is_dummy_run, + ) else: self.proposer.run(share_inputs=self.share_inputs) # 7. Updata 'infer_seed' and step_paddle() - self.share_inputs["infer_seed"].add_(self.infer_seed_increment) - self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED + if not self.speculative_decoding: + self.share_inputs["infer_seed"].add_(self.infer_seed_increment) + self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED if self.speculative_decoding: speculate_schedule_cache( diff --git a/tests/xpu_ci/4cards_cases/run_mtp_cudagraph.py b/tests/xpu_ci/4cards_cases/test_mtp_cudagraph.py similarity index 100% rename from tests/xpu_ci/4cards_cases/run_mtp_cudagraph.py rename to tests/xpu_ci/4cards_cases/test_mtp_cudagraph.py