186 changes: 186 additions & 0 deletions csrc/engine/compiler/chunk_prefill_compiler.cpp
@@ -0,0 +1,186 @@
#include "chunk_prefill_compiler.hpp"
#include "infinicore/context/context.hpp"


namespace {
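// Fill a device tensor with zeros by staging a zeroed host buffer and copying it H2D.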
inline void set_zeros(infinicore::Tensor &tensor) {
std::vector<uint8_t> zeros(tensor->nbytes(), 0);
infinicore::context::memcpyH2D(tensor->data(), zeros.data(), tensor->nbytes(), false);
}
} // namespace

namespace infinilm::engine {

ChunkPrefillCompiler::ChunkPrefillCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier)
: GraphCompiler(model, barrier) {
// Enumerate chunk sizes for chunk-prefill
for (size_t cs : {64, 128, 256, 512, 1024, 2048}) {
chunk_sizes_.push_back(cs);
}
// Enumerate batch sizes for prefill (typically smaller than decode)
for (size_t b = 1; b < 32; b++) {
prefill_batch_sizes_.push_back(b);
}
for (size_t b = 32; b < 64; b += 8) {
prefill_batch_sizes_.push_back(b);
}
for (size_t b = 64; b < 128; b += 16) {
prefill_batch_sizes_.push_back(b);
}
for (size_t b = 128; b < 256; b += 32) {
prefill_batch_sizes_.push_back(b);
}
for (size_t b = 256; b <= 512; b += 64) {
prefill_batch_sizes_.push_back(b);
}
}

void ChunkPrefillCompiler::compile() {
if (model_->get_cache_config() != nullptr &&
dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config())) {

const auto *paged_config =
dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config());
size_t nblocks = paged_config->num_blocks();

compiled_map_prefill_.clear();

// Max total tokens to avoid OOM during graph recording
constexpr size_t MAX_TOTAL_TOKENS = 4096;
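// e.g. (b=8, cs=512) and (b=64, cs=64) hit exactly 4096 tokens and are recorded,
// while (b=8, cs=1024) exceeds the cap and is skipped.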

// Pre-allocate a shared block_tables holder sized to cover the largest batch size we'll use
size_t max_batch = *std::max_element(prefill_batch_sizes_.begin(), prefill_batch_sizes_.end());
size_t block_per_req = nblocks / max_batch;
block_tables_holder_ = infinicore::Tensor::empty(
{nblocks}, infinicore::DataType::I32, infinicore::context::getDevice());
set_zeros(block_tables_holder_);

for (size_t b : prefill_batch_sizes_) {
for (size_t cs : chunk_sizes_) {
size_t total_tokens = b * cs;
if (total_tokens > MAX_TOTAL_TOKENS) {
continue;
}

size_t bpr = nblocks / b; // block_per_req for this batch size

InfinilmModel::Input input;

// input_ids: [1, total_tokens] — all tokens for this batch packed together
input.input_ids = infinicore::Tensor::empty(
{1, total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice());
set_zeros(input.input_ids.value());

// position_ids: [total_tokens]
input.position_ids = infinicore::Tensor::empty(
{total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice());
set_zeros(input.position_ids.value());

// total_sequence_lengths: [b], set to cs (first-chunk scenario)
input.total_sequence_lengths = infinicore::Tensor::empty(
{b}, infinicore::DataType::I32, infinicore::context::getDevice());
{
std::vector<int32_t> tsl(b, static_cast<int32_t>(cs));
infinicore::context::memcpyH2D(
input.total_sequence_lengths.value()->data(),
tsl.data(), b * sizeof(int32_t), false);
}

// input_offsets: [b+1], stride = cs
input.input_offsets = infinicore::Tensor::empty(
{b + 1}, infinicore::DataType::I32, infinicore::context::getDevice());
{
std::vector<int32_t> offsets(b + 1);
for (size_t i = 0; i <= b; i++) {
offsets[i] = static_cast<int32_t>(i * cs);
}
infinicore::context::memcpyH2D(
input.input_offsets.value()->data(),
offsets.data(), (b + 1) * sizeof(int32_t), false);
}

// cu_seqlens: [b+1], same layout as input_offsets for prefill
input.cu_seqlens = infinicore::Tensor::empty(
{b + 1}, infinicore::DataType::I32, infinicore::context::getDevice());
{
std::vector<int32_t> cu(b + 1);
for (size_t i = 0; i <= b; i++) {
cu[i] = static_cast<int32_t>(i * cs);
}
infinicore::context::memcpyH2D(
input.cu_seqlens.value()->data(),
cu.data(), (b + 1) * sizeof(int32_t), false);
}

// block_tables: view into the shared holder [b, bpr]
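// A single {nblocks} holder is shared across all (b, cs) graphs; each graph records its own [b, bpr] strided view of it.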
input.block_tables = block_tables_holder_->as_strided(
{b, bpr}, {(ptrdiff_t)bpr, 1});

// slot_mapping: [total_tokens]
input.slot_mapping = infinicore::Tensor::empty(
{total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice());
set_zeros(input.slot_mapping.value());

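// Record one forward pass as a graph; the surrounding barrier waits keep all ranks entering and leaving recording together.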
barrier_->wait();
infinicore::context::startGraphRecording();
auto output = model_->forward(input);
auto graph = infinicore::context::stopGraphRecording();
barrier_->wait();

auto shared_output = std::shared_ptr<InfinilmModel::Output>(
new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)});

compiled_map_prefill_[std::make_tuple(b, cs)] =
CompiledResult{std::move(input), std::make_tuple(graph, shared_output)};
}
}
}
}

ChunkPrefillCompiler::Compiled ChunkPrefillCompiler::get_compiled(const InfinilmModel::Input &input) {
if (model_->get_cache_config() == nullptr ||
!dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config())) {
return {nullptr, nullptr};
}

if (!input.block_tables.has_value() || !input.input_ids.has_value()) {
return {nullptr, nullptr};
}

size_t batch_size = input.block_tables.value()->size(0);
size_t block_per_req = input.block_tables.value()->size(1);
size_t total_tokens = input.input_ids.value()->size(1);

// Prefill: total_tokens is a multiple of batch_size, and chunk_size > 1
if (total_tokens == 0 || total_tokens % batch_size != 0) {
return {nullptr, nullptr};
}
size_t chunk_size = total_tokens / batch_size;
if (chunk_size <= 1) {
// Single-token case belongs to decode
return {nullptr, nullptr};
}

auto result = compiled_map_prefill_.find(std::make_tuple(batch_size, chunk_size));
if (result == compiled_map_prefill_.end()) {
return {nullptr, nullptr};
}

auto &graph_input = result->second.input;

graph_input.input_ids.value()->copy_from(input.input_ids.value());
graph_input.position_ids.value()->copy_from(input.position_ids.value());
graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());
graph_input.input_offsets.value()->copy_from(input.input_offsets.value());
graph_input.cu_seqlens.value()->copy_from(input.cu_seqlens.value());
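// The recorded block_tables view is [b, nblocks / b]; narrow dim 1 to the incoming block_per_req so copy_from shapes line up.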
graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value());
graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value());

auto graph = std::get<0>(result->second.compiled);
auto shared_output = std::shared_ptr<InfinilmModel::Output>(
new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()});

return std::make_tuple(graph, shared_output);
}

} // namespace infinilm::engine
42 changes: 42 additions & 0 deletions csrc/engine/compiler/chunk_prefill_compiler.hpp
@@ -0,0 +1,42 @@
#pragma once

#include "graph_compiler.hpp"

#include <tuple>
#include <unordered_map>
#include <vector>

namespace infinilm::engine {
class ChunkPrefillCompiler : public GraphCompiler {
public:
ChunkPrefillCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier);

void compile() override;

Compiled get_compiled(const InfinilmModel::Input &input) override;

private:
struct TupleHash {
size_t operator()(const std::tuple<size_t, size_t> &t) const noexcept {
auto h1 = std::hash<size_t>{}(std::get<0>(t));
auto h2 = std::hash<size_t>{}(std::get<1>(t));
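// Boost-style hash_combine; 0x9e3779b97f4a7c15 is the 64-bit golden-ratio constant used to mix the two hashes.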
return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2));
}
};

std::vector<size_t> chunk_sizes_;
std::vector<size_t> prefill_batch_sizes_;

infinicore::Tensor block_tables_holder_;

struct CompiledResult {
InfinilmModel::Input input;
Compiled compiled;
};

// Key: (batch_size, chunk_size)
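// e.g. a batch of 8 requests, each contributing a 256-token chunk, looks up key (8, 256).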
std::unordered_map<
std::tuple<size_t, size_t>,
CompiledResult,
TupleHash>
compiled_map_prefill_;
};
} // namespace infinilm::engine
12 changes: 11 additions & 1 deletion csrc/engine/compiler/general_compiler.cpp
@@ -1,13 +1,18 @@
#include "general_compiler.hpp"

namespace infinilm::engine {
GeneralCompiler::GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier) : GraphCompiler(model, barrier) {
GeneralCompiler::GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier, bool enable_chunk_prefill_graph)
: GraphCompiler(model, barrier), enable_chunk_prefill_graph_(enable_chunk_prefill_graph) {
static_batching_compiler_ = std::make_unique<StaticBatchingCompiler>(model_, barrier);
chunk_prefill_compiler_ = std::make_unique<ChunkPrefillCompiler>(model_, barrier);
paged_compiler_ = std::make_unique<PagedCompiler>(model_, barrier);
}

void GeneralCompiler::compile() {
static_batching_compiler_->compile();
if (enable_chunk_prefill_graph_) {
chunk_prefill_compiler_->compile();
}
paged_compiler_->compile();
}

@@ -19,6 +24,11 @@ GeneralCompiler::Compiled GeneralCompiler::get_compiled(const InfinilmModel::Inp
if (std::get<0>(result) != nullptr && std::get<1>(result) != nullptr) {
return result;
}
// chunk-prefill must be checked before decode (decode would also match if chunk_size==1)
result = chunk_prefill_compiler_.get()->get_compiled(input);
if (std::get<0>(result) != nullptr && std::get<1>(result) != nullptr) {
return result;
}
result = paged_compiler_.get()->get_compiled(input);
return result;
}
5 changes: 4 additions & 1 deletion csrc/engine/compiler/general_compiler.hpp
@@ -1,12 +1,13 @@
#pragma once

#include "chunk_prefill_compiler.hpp"
Collaborator commented: Where is this file?

#include "paged_compiler.hpp"
#include "static_batching_compiler.hpp"

namespace infinilm::engine {
class GeneralCompiler : public GraphCompiler {
public:
GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier);
GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier, bool enable_chunk_prefill_graph = false);

void compile() override;

@@ -15,5 +16,7 @@ class GeneralCompiler : public GraphCompiler {
private:
std::unique_ptr<StaticBatchingCompiler> static_batching_compiler_;
std::unique_ptr<PagedCompiler> paged_compiler_;
std::unique_ptr<ChunkPrefillCompiler> chunk_prefill_compiler_;
bool enable_chunk_prefill_graph_;
};
} // namespace infinilm::engine
4 changes: 4 additions & 0 deletions csrc/engine/infer_engine.cpp
@@ -25,6 +25,7 @@ InferEngine::InferEngine(
infinicore::Device::Type device_type,
const cache::CacheConfig *cache_config,
bool enable_graph_compiling,
bool enable_chunk_prefill_graph,
backends::AttentionBackend attention_backend) // Changed parameter
: communication_group_(distributed_config, device_type),
legacy_model_config_(config),
@@ -43,6 +44,7 @@ InferEngine::InferEngine(
cache_config_ != nullptr ? cache_config_.get() : nullptr,
barrier_.get(),
enable_graph_compiling,
enable_chunk_prefill_graph,
attention_backend_));
}

@@ -56,6 +58,7 @@ InferEngine::InferEngine(
infinicore::Device::Type device_type,
const cache::CacheConfig *cache_config,
bool enable_graph_compiling,
bool enable_chunk_prefill_graph,
backends::AttentionBackend attention_backend,
std::optional<infinicore::DataType> kv_cache_dtype) // Changed parameter
: communication_group_(distributed_config, device_type), attention_backend_(attention_backend) {
@@ -82,6 +85,7 @@ InferEngine::InferEngine(
cache_config_ != nullptr ? cache_config_.get() : nullptr,
barrier_.get(),
enable_graph_compiling,
enable_chunk_prefill_graph,
attention_backend_));
}
// Compile the model on all workers
2 changes: 2 additions & 0 deletions csrc/engine/infer_engine.hpp
@@ -39,6 +39,7 @@ class InferEngine {
infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
const cache::CacheConfig *cache_config = nullptr,
bool enable_graph_compiling = false,
bool enable_chunk_prefill_graph = false,
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);

InferEngine(
@@ -47,6 +48,7 @@ InferEngine(
infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
const cache::CacheConfig *cache_config = nullptr,
bool enable_graph_compiling = false,
bool enable_chunk_prefill_graph = false,
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default,
std::optional<infinicore::DataType> kv_cache_dtype = std::nullopt);

6 changes: 5 additions & 1 deletion csrc/engine/rank_worker.cpp
@@ -27,11 +27,13 @@ RankWorker::RankWorker(const InfinilmModel::Config &model_config,
const cache::CacheConfig *cache_config,
RankBarrier *barrier,
bool enable_graph_compiling,
bool enable_chunk_prefill_graph,
backends::AttentionBackend attention_backend)
: legacy_model_config_(model_config),
rank_info_(rank_info),
attention_backend_(attention_backend),
enable_graph_compiling_(enable_graph_compiling),
enable_chunk_prefill_graph_(enable_chunk_prefill_graph),
job_cmd_(Command::INIT),
has_job_(false),
job_done_(false),
@@ -56,12 +58,14 @@ RankWorker::RankWorker(
const cache::CacheConfig *cache_config,
RankBarrier *barrier,
bool enable_graph_compiling,
bool enable_chunk_prefill_graph,
backends::AttentionBackend attention_backend)
: infinilm_config_(infinilm_config),
model_config_(infinilm_config->model_config),
rank_info_(rank_info),
attention_backend_(attention_backend),
enable_graph_compiling_(enable_graph_compiling),
enable_chunk_prefill_graph_(enable_chunk_prefill_graph),
job_cmd_(Command::INIT),
has_job_(false),
job_done_(false),
@@ -303,7 +307,7 @@ void RankWorker::thread_loop() {
throw std::runtime_error("Failed to create model");
}
if (enable_graph_compiling_) {
compiler_ = std::make_unique<GeneralCompiler>(model_, barrier_);
compiler_ = std::make_unique<GeneralCompiler>(model_, barrier_, enable_chunk_prefill_graph_);
}

init_done_ = true;
3 changes: 3 additions & 0 deletions csrc/engine/rank_worker.hpp
@@ -75,13 +75,15 @@ class RankWorker {
const cache::CacheConfig *cache_config,
RankBarrier *barrier,
bool enable_graph_compiling,
bool enable_chunk_prefill_graph,
backends::AttentionBackend attention_backend);

RankWorker(std::shared_ptr<infinilm::global_state::InfinilmConfig> infinilm_config,
const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config,
RankBarrier *barrier,
bool enable_graph_compiling,
bool enable_chunk_prefill_graph,
backends::AttentionBackend attention_backend);

// Submit a parameter load job and wait until the load completes on the worker thread.
@@ -131,6 +133,7 @@

// Graph Compiling
bool enable_graph_compiling_;
bool enable_chunk_prefill_graph_;
std::unique_ptr<GraphCompiler> compiler_;

// Command for the pending job (protected by mutex_)