From de64005d37cc00ed53a72846eb7acbc83f60f107 Mon Sep 17 00:00:00 2001 From: wili Date: Thu, 25 Jun 2026 01:25:48 -0700 Subject: [PATCH] [https://nvbugs/6242591][fix] Fix bugs in Beam Search kernels Signed-off-by: wili-65535 <12345678+wili@users.noreply.github.com> --- cpp/tensorrt_llm/kernels/beamSearchKernels.cu | 24 ++++++--- cpp/tensorrt_llm/kernels/beamSearchKernels.h | 16 +++--- .../beamSearchKernelsTemplate.h | 51 ++++++++++++------- 3 files changed, 59 insertions(+), 32 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/beamSearchKernels.cu b/cpp/tensorrt_llm/kernels/beamSearchKernels.cu index 005a1539168e..c00b9de31616 100644 --- a/cpp/tensorrt_llm/kernels/beamSearchKernels.cu +++ b/cpp/tensorrt_llm/kernels/beamSearchKernels.cu @@ -16,6 +16,7 @@ #include "tensorrt_llm/common/config.h" #include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/reduceKernelUtils.cuh" #include "tensorrt_llm/kernels/beamSearchKernels.h" using namespace tensorrt_llm::common; @@ -135,21 +136,26 @@ void invokeUpdateCacheIndirection(int* tgtCI, int const* srcCI, BeamHypotheses& sync_check_cuda_error(stream); } -__global__ void addCumLogProbs(float* __restrict pStage1LogProbs, float const* __restrict cumLogProbs, - FinishedState const* finished, int const* endIds, float const* diversityRates, +__global__ void addCumLogProbs(float* __restrict pStage1LogProbs, int const* __restrict pStage1Ids, + float const* __restrict cumLogProbs, FinishedState const* finished, int const* endIds, float const* diversityRates, runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM) { int const bid = blockIdx.x; // Index of request in batch runtime::SizeType32 const slot = batchSlots[bid]; float const diversityRate{diversityRates[slot]}; float* pLocalLogProbs = pStage1LogProbs + bid * nBMIn * nBMOut * 2; + int const* pLocalIds = pStage1Ids + bid * nBMIn * nBMOut * 2; for (int i = threadIdx.x; i < nBMIn * nBMOut * 2; i += blockDim.x) { int const iBMIn = i / (nBMOut * 2); if (finished[slot * nBMIn + iBMIn].isFinished()) { - pLocalLogProbs[i] += (i == endIds[slot]) ? 1.0f : 0.0f; + // In V2 path, i is a candidate-slot index (0..nBMIn*nBMOut*2-1), NOT a vocab token id. + // Use pStage1Ids to look up the actual token id for the EOS comparison. + bool const isEOS = (pLocalIds[i] == endIds[slot]); + // Keep only the EOS candidate with its proper cumulative score; suppress all others. + pLocalLogProbs[i] = isEOS ? (pLocalLogProbs[i] + cumLogProbs[slot * nBM + iBMIn]) : -FLT_MAX; } else { @@ -160,21 +166,27 @@ __global__ void addCumLogProbs(float* __restrict pStage1LogProbs, float const* _ return; } -__global__ void addCumLogProbs(half* __restrict pStage1LogProbs, float const* __restrict cumLogProbs, - FinishedState const* finished, int const* endIds, float const* diversityRates, +__global__ void addCumLogProbs(half* __restrict pStage1LogProbs, int const* __restrict pStage1Ids, + float const* __restrict cumLogProbs, FinishedState const* finished, int const* endIds, float const* diversityRates, runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM) { int const bid = blockIdx.x; // Index of request in batch runtime::SizeType32 const slot = batchSlots[bid]; float const diversityRate{diversityRates[slot]}; half* pLocalLogProbs = pStage1LogProbs + bid * nBMIn * nBMOut * 2; + int const* pLocalIds = pStage1Ids + bid * nBMIn * nBMOut * 2; for (int i = threadIdx.x; i < nBMIn * nBMOut * 2; i += blockDim.x) { int const iBMIn = i / (nBMOut * 2); if (finished[slot * nBMIn + iBMIn].isFinished()) { - pLocalLogProbs[i] += (i == endIds[slot]) ? 1.0f : 0.0f; + // In V2 path, i is a candidate-slot index (0..nBMIn*nBMOut*2-1), NOT a vocab token id. + // Use pStage1Ids to look up the actual token id for the EOS comparison. + bool const isEOS = (pLocalIds[i] == endIds[slot]); + // Keep only the EOS candidate with its proper cumulative score; suppress all others. + pLocalLogProbs[i] + = isEOS ? (half) (float(pLocalLogProbs[i]) + cumLogProbs[slot * nBM + iBMIn]) : (half) -HALF_FLT_MAX; } else { diff --git a/cpp/tensorrt_llm/kernels/beamSearchKernels.h b/cpp/tensorrt_llm/kernels/beamSearchKernels.h index d8a9266e9406..345e4659c941 100644 --- a/cpp/tensorrt_llm/kernels/beamSearchKernels.h +++ b/cpp/tensorrt_llm/kernels/beamSearchKernels.h @@ -131,13 +131,15 @@ void invokeTopkBeamSearch(T const* logProbs, T const* bias, void* workspace, Bea void invokeUpdateCacheIndirection(int* tgtCI, int const* srcCI, BeamHypotheses& bh, runtime::SizeType32 const maxAttentionWindow, runtime::SizeType32 sinkTokenLength, cudaStream_t stream); -__global__ void addCumLogProbs(float* __restrict pStage1LogProbs, float const* __restrict cumLogProbs, - ::tensorrt_llm::kernels::FinishedState const* finished, int const* endIds, float const* diversityRates, - runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM); - -__global__ void addCumLogProbs(half* __restrict pStage1LogProbs, float const* __restrict cumLogProbs, - ::tensorrt_llm::kernels::FinishedState const* finished, int const* endIds, float const* diversityRates, - runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM); +__global__ void addCumLogProbs(float* __restrict pStage1LogProbs, int const* __restrict pStage1Ids, + float const* __restrict cumLogProbs, ::tensorrt_llm::kernels::FinishedState const* finished, int const* endIds, + float const* diversityRates, runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, + size_t const nBMOut, size_t const nBM); + +__global__ void addCumLogProbs(half* __restrict pStage1LogProbs, int const* __restrict pStage1Ids, + float const* __restrict cumLogProbs, ::tensorrt_llm::kernels::FinishedState const* finished, int const* endIds, + float const* diversityRates, runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, + size_t const nBMOut, size_t const nBM); __global__ void gatherId(int const* __restrict pStage1Id, int* __restrict pStage2Id, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nV); diff --git a/cpp/tensorrt_llm/kernels/beamSearchKernels/beamSearchKernelsTemplate.h b/cpp/tensorrt_llm/kernels/beamSearchKernels/beamSearchKernelsTemplate.h index eb0d9e072997..df9228961c53 100644 --- a/cpp/tensorrt_llm/kernels/beamSearchKernels/beamSearchKernelsTemplate.h +++ b/cpp/tensorrt_llm/kernels/beamSearchKernels/beamSearchKernelsTemplate.h @@ -208,6 +208,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void beamStage3Kernel( __shared__ float smemCumLogProbs[PBM]; __shared__ int smemSeqLen[PBM]; __shared__ KVPair smemTopKV[(IS_V2) ? 1 : PBM * 2]; // Just a placeholder in V2 workflow + __shared__ int smemNBeamForNextStep; if (bh.numBeamsCBA != nullptr) { @@ -217,14 +218,11 @@ __launch_bounds__(BLOCK_SIZE) __global__ void beamStage3Kernel( // Initialize worst score in the first call bh.minNormedScoresCBA[slot] = 0.0f; // logProbs is in range (-inf, 0] } - else if (earlyStopping == 1 && bh.numBeamsCBA[slot] == nBM - || earlyStopping != 1 && bh.finished[slot * nBM].isFinished()) + else if (earlyStopping == 1 && bh.numBeamsCBA[slot] >= nBM || earlyStopping != 1 && bh.batchDones[slot]) { // Condition of early return: // 1. In EarlyStopping mode, and we have got enough beams // 2. In NonEarlyStopping mode, and this batch has been marked as done - // TODO: improve the condition like below - // earlyStopping == 1 && bh.numBeamsCBA[slot] == nBM || earlyStopping != 1 && bh.batchDones[slot] return; } } @@ -324,10 +322,14 @@ __launch_bounds__(BLOCK_SIZE) __global__ void beamStage3Kernel( { // Condition of this branch: // This token is end-token and belongs to top nBM range in Beam search mode - int const nSeqLen = bh.sequenceLengths[slot * nBM + i] + 1 - bh.inputLengths[slot * nBM + i]; + // Use the actual parent beam index (topId / nV) % nBM, not the candidate rank i, + // to look up the correct sequenceLength and inputLength for length-penalty scoring. + int const parentBeam = (topId / nV) % nBM; + int const nSeqLen + = bh.sequenceLengths[slot * nBM + parentBeam] + 1 - bh.inputLengths[slot * nBM + parentBeam]; float const score = applyLengthPenalty(topLogProb, nSeqLen, lengthPenalty); int nCBA = bh.numBeamsCBA[slot]; - if (nCBA == nBM) + if (nCBA >= nBM) { // There are already nBM beams if (score < bh.minNormedScoresCBA[slot]) @@ -437,6 +439,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void beamStage3Kernel( break; } } + smemNBeamForNextStep = nBeamForNextStep; } // Update bh.batchDones @@ -481,23 +484,33 @@ __launch_bounds__(BLOCK_SIZE) __global__ void beamStage3Kernel( if (tid < nBMOut) { int const indexBatchBeam = slot * nBM + tid; - int const step = smemSeqLen[tid]; - if (!bh.finished[indexBatchBeam].isFinished()) + if (tid < smemNBeamForNextStep) { - smemSeqLen[tid]++; + // This slot received a valid next-step token from the selection phase. + int const step = smemSeqLen[tid]; + if (!bh.finished[indexBatchBeam].isFinished()) + { + smemSeqLen[tid]++; + } + int const newId = bh.outputIdsPtr[slot][tid * nMSL + step]; + int const newBeamId = (newId / nV) % nBM; + int const newTokenId = newId % nV; + bh.sequenceLengths[indexBatchBeam] = smemSeqLen[newBeamId]; + if (newTokenId == bh.endIds[slot]) + { + bh.finished[indexBatchBeam].setFinishedEOS(); + } + bh.parentIdsPtr[slot][tid * nMSL + step] = newBeamId; + bh.outputIdsPtr[slot][tid * nMSL + step] = newTokenId; } - int const newId = bh.outputIdsPtr[slot][tid * nMSL + step]; - int const newBeamId = (newId / nV) % nBM; - int const newTokenId = newId % nV; - bh.sequenceLengths[indexBatchBeam] = smemSeqLen[newBeamId]; - if (newTokenId == bh.endIds[slot]) + else { - bh.finished[indexBatchBeam].setFinishedEOS(); + // No valid next-step token for this slot: all top candidates went to CBA. + // Mark as finished so downstream stages (cache indirection, next decode) skip it. + bh.finished[indexBatchBeam].setFinished(); } - bh.parentIdsPtr[slot][tid * nMSL + step] = newBeamId; - bh.outputIdsPtr[slot][tid * nMSL + step] = newTokenId; - if ((earlyStopping == 1) && (bh.numBeamsCBA != nullptr && bh.numBeamsCBA[slot] == nBM) + if ((earlyStopping == 1) && (bh.numBeamsCBA != nullptr && bh.numBeamsCBA[slot] >= nBM) || (earlyStopping != 1) && bh.batchDones[slot]) { bh.batchDones[slot] = true; @@ -631,7 +644,7 @@ void beamSearchKernelLauncher( sync_check_cuda_error(stream); int nThread = min(roundUp(nBMIn * nBMOut * 2, 32), 1024); - addCumLogProbs<<>>(pStage1LogProbs, bh.cumLogProbs, bh.finished, bh.endIds, + addCumLogProbs<<>>(pStage1LogProbs, pStage1Ids, bh.cumLogProbs, bh.finished, bh.endIds, bh.diversityRates, bh.batchSlots, nBS, nBMIn, nBMOut, nBM); sync_check_cuda_error(stream);