Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions cpp/tensorrt_llm/kernels/beamSearchKernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include "tensorrt_llm/kernels/beamSearchKernels.h"

using namespace tensorrt_llm::common;
Expand Down Expand Up @@ -135,21 +136,26 @@ void invokeUpdateCacheIndirection(int* tgtCI, int const* srcCI, BeamHypotheses&
sync_check_cuda_error(stream);
}

__global__ void addCumLogProbs(float* __restrict pStage1LogProbs, float const* __restrict cumLogProbs,
FinishedState const* finished, int const* endIds, float const* diversityRates,
__global__ void addCumLogProbs(float* __restrict pStage1LogProbs, int const* __restrict pStage1Ids,
float const* __restrict cumLogProbs, FinishedState const* finished, int const* endIds, float const* diversityRates,
runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM)
{
int const bid = blockIdx.x; // Index of request in batch
runtime::SizeType32 const slot = batchSlots[bid];
float const diversityRate{diversityRates[slot]};
float* pLocalLogProbs = pStage1LogProbs + bid * nBMIn * nBMOut * 2;
int const* pLocalIds = pStage1Ids + bid * nBMIn * nBMOut * 2;

for (int i = threadIdx.x; i < nBMIn * nBMOut * 2; i += blockDim.x)
{
int const iBMIn = i / (nBMOut * 2);
if (finished[slot * nBMIn + iBMIn].isFinished())
{
pLocalLogProbs[i] += (i == endIds[slot]) ? 1.0f : 0.0f;
// In V2 path, i is a candidate-slot index (0..nBMIn*nBMOut*2-1), NOT a vocab token id.
// Use pStage1Ids to look up the actual token id for the EOS comparison.
bool const isEOS = (pLocalIds[i] == endIds[slot]);
// Keep only the EOS candidate with its proper cumulative score; suppress all others.
pLocalLogProbs[i] = isEOS ? (pLocalLogProbs[i] + cumLogProbs[slot * nBM + iBMIn]) : -FLT_MAX;
}
else
{
Expand All @@ -160,21 +166,27 @@ __global__ void addCumLogProbs(float* __restrict pStage1LogProbs, float const* _
return;
}

__global__ void addCumLogProbs(half* __restrict pStage1LogProbs, float const* __restrict cumLogProbs,
FinishedState const* finished, int const* endIds, float const* diversityRates,
__global__ void addCumLogProbs(half* __restrict pStage1LogProbs, int const* __restrict pStage1Ids,
float const* __restrict cumLogProbs, FinishedState const* finished, int const* endIds, float const* diversityRates,
runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM)
{
int const bid = blockIdx.x; // Index of request in batch
runtime::SizeType32 const slot = batchSlots[bid];
float const diversityRate{diversityRates[slot]};
half* pLocalLogProbs = pStage1LogProbs + bid * nBMIn * nBMOut * 2;
int const* pLocalIds = pStage1Ids + bid * nBMIn * nBMOut * 2;

for (int i = threadIdx.x; i < nBMIn * nBMOut * 2; i += blockDim.x)
{
int const iBMIn = i / (nBMOut * 2);
if (finished[slot * nBMIn + iBMIn].isFinished())
{
pLocalLogProbs[i] += (i == endIds[slot]) ? 1.0f : 0.0f;
// In V2 path, i is a candidate-slot index (0..nBMIn*nBMOut*2-1), NOT a vocab token id.
// Use pStage1Ids to look up the actual token id for the EOS comparison.
bool const isEOS = (pLocalIds[i] == endIds[slot]);
// Keep only the EOS candidate with its proper cumulative score; suppress all others.
pLocalLogProbs[i]
= isEOS ? (half) (float(pLocalLogProbs[i]) + cumLogProbs[slot * nBM + iBMIn]) : (half) -HALF_FLT_MAX;
}
else
{
Expand Down
16 changes: 9 additions & 7 deletions cpp/tensorrt_llm/kernels/beamSearchKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,15 @@ void invokeTopkBeamSearch(T const* logProbs, T const* bias, void* workspace, Bea
void invokeUpdateCacheIndirection(int* tgtCI, int const* srcCI, BeamHypotheses& bh,
runtime::SizeType32 const maxAttentionWindow, runtime::SizeType32 sinkTokenLength, cudaStream_t stream);

__global__ void addCumLogProbs(float* __restrict pStage1LogProbs, float const* __restrict cumLogProbs,
::tensorrt_llm::kernels::FinishedState const* finished, int const* endIds, float const* diversityRates,
runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM);

__global__ void addCumLogProbs(half* __restrict pStage1LogProbs, float const* __restrict cumLogProbs,
::tensorrt_llm::kernels::FinishedState const* finished, int const* endIds, float const* diversityRates,
runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM);
__global__ void addCumLogProbs(float* __restrict pStage1LogProbs, int const* __restrict pStage1Ids,
float const* __restrict cumLogProbs, ::tensorrt_llm::kernels::FinishedState const* finished, int const* endIds,
float const* diversityRates, runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn,
size_t const nBMOut, size_t const nBM);

__global__ void addCumLogProbs(half* __restrict pStage1LogProbs, int const* __restrict pStage1Ids,
float const* __restrict cumLogProbs, ::tensorrt_llm::kernels::FinishedState const* finished, int const* endIds,
float const* diversityRates, runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn,
size_t const nBMOut, size_t const nBM);

__global__ void gatherId(int const* __restrict pStage1Id, int* __restrict pStage2Id, size_t const nBS,
size_t const nBMIn, size_t const nBMOut, size_t const nV);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void beamStage3Kernel(
__shared__ float smemCumLogProbs[PBM];
__shared__ int smemSeqLen[PBM];
__shared__ KVPair smemTopKV[(IS_V2) ? 1 : PBM * 2]; // Just a placeholder in V2 workflow
__shared__ int smemNBeamForNextStep;

if (bh.numBeamsCBA != nullptr)
{
Expand All @@ -217,14 +218,11 @@ __launch_bounds__(BLOCK_SIZE) __global__ void beamStage3Kernel(
// Initialize worst score in the first call
bh.minNormedScoresCBA[slot] = 0.0f; // logProbs is in range (-inf, 0]
}
else if (earlyStopping == 1 && bh.numBeamsCBA[slot] == nBM
|| earlyStopping != 1 && bh.finished[slot * nBM].isFinished())
else if (earlyStopping == 1 && bh.numBeamsCBA[slot] >= nBM || earlyStopping != 1 && bh.batchDones[slot])
{
// Condition of early return:
// 1. In EarlyStopping mode, and we have got enough beams
// 2. In NonEarlyStopping mode, and this batch has been marked as done
// TODO: improve the condition like below
// earlyStopping == 1 && bh.numBeamsCBA[slot] == nBM || earlyStopping != 1 && bh.batchDones[slot]
return;
}
}
Expand Down Expand Up @@ -324,10 +322,14 @@ __launch_bounds__(BLOCK_SIZE) __global__ void beamStage3Kernel(
{
// Condition of this branch:
// This token is end-token and belongs to top nBM range in Beam search mode
int const nSeqLen = bh.sequenceLengths[slot * nBM + i] + 1 - bh.inputLengths[slot * nBM + i];
// Use the actual parent beam index (topId / nV) % nBM, not the candidate rank i,
// to look up the correct sequenceLength and inputLength for length-penalty scoring.
int const parentBeam = (topId / nV) % nBM;
int const nSeqLen
= bh.sequenceLengths[slot * nBM + parentBeam] + 1 - bh.inputLengths[slot * nBM + parentBeam];
float const score = applyLengthPenalty(topLogProb, nSeqLen, lengthPenalty);
int nCBA = bh.numBeamsCBA[slot];
if (nCBA == nBM)
if (nCBA >= nBM)
{
// There are already nBM beams
if (score < bh.minNormedScoresCBA[slot])
Expand Down Expand Up @@ -437,6 +439,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void beamStage3Kernel(
break;
}
}
smemNBeamForNextStep = nBeamForNextStep;
}

// Update bh.batchDones
Expand Down Expand Up @@ -481,23 +484,33 @@ __launch_bounds__(BLOCK_SIZE) __global__ void beamStage3Kernel(
if (tid < nBMOut)
{
int const indexBatchBeam = slot * nBM + tid;
int const step = smemSeqLen[tid];
if (!bh.finished[indexBatchBeam].isFinished())
if (tid < smemNBeamForNextStep)
{
smemSeqLen[tid]++;
// This slot received a valid next-step token from the selection phase.
int const step = smemSeqLen[tid];
if (!bh.finished[indexBatchBeam].isFinished())
{
smemSeqLen[tid]++;
}
int const newId = bh.outputIdsPtr[slot][tid * nMSL + step];
int const newBeamId = (newId / nV) % nBM;
int const newTokenId = newId % nV;
bh.sequenceLengths[indexBatchBeam] = smemSeqLen[newBeamId];

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the new write-back loop, smemSeqLen[] is only incremented for slots with tid < smemNBeamForNextStep; the rest keep their stale pre-step value. But the child beam's length is then read as smemSeqLen[newBeamId], where newBeamId = (newId / nV) % nBM can be any value in [0, nBM) — so it may index a slot that wasn't incremented this step, giving the child a stale (one-short) sequence length that later misindexes outputIds and corrupts length-penalty scoring.

Is newBeamId guaranteed to fall within [0, smemNBeamForNextStep) in the underfilled case?

if (newTokenId == bh.endIds[slot])
{
bh.finished[indexBatchBeam].setFinishedEOS();
}
bh.parentIdsPtr[slot][tid * nMSL + step] = newBeamId;
bh.outputIdsPtr[slot][tid * nMSL + step] = newTokenId;
}
int const newId = bh.outputIdsPtr[slot][tid * nMSL + step];
int const newBeamId = (newId / nV) % nBM;
int const newTokenId = newId % nV;
bh.sequenceLengths[indexBatchBeam] = smemSeqLen[newBeamId];
if (newTokenId == bh.endIds[slot])
else
{
bh.finished[indexBatchBeam].setFinishedEOS();
// No valid next-step token for this slot: all top candidates went to CBA.
// Mark as finished so downstream stages (cache indirection, next decode) skip it.
bh.finished[indexBatchBeam].setFinished();
}
bh.parentIdsPtr[slot][tid * nMSL + step] = newBeamId;
bh.outputIdsPtr[slot][tid * nMSL + step] = newTokenId;

if ((earlyStopping == 1) && (bh.numBeamsCBA != nullptr && bh.numBeamsCBA[slot] == nBM)
if ((earlyStopping == 1) && (bh.numBeamsCBA != nullptr && bh.numBeamsCBA[slot] >= nBM)
|| (earlyStopping != 1) && bh.batchDones[slot])
{
bh.batchDones[slot] = true;
Expand Down Expand Up @@ -631,7 +644,7 @@ void beamSearchKernelLauncher(
sync_check_cuda_error(stream);

int nThread = min(roundUp(nBMIn * nBMOut * 2, 32), 1024);
addCumLogProbs<<<nBS, nThread, 0, stream>>>(pStage1LogProbs, bh.cumLogProbs, bh.finished, bh.endIds,
addCumLogProbs<<<nBS, nThread, 0, stream>>>(pStage1LogProbs, pStage1Ids, bh.cumLogProbs, bh.finished, bh.endIds,
bh.diversityRates, bh.batchSlots, nBS, nBMIn, nBMOut, nBM);
sync_check_cuda_error(stream);

Expand Down
Loading