From f2c8bab8531da961de09e4bf808695a4a269bb75 Mon Sep 17 00:00:00 2001 From: huidesheng <1832140001@qq.com> Date: Tue, 12 May 2026 17:45:29 +0800 Subject: [PATCH 1/2] add chunkprefill and prefill cuda graph --- csrc/engine/compiler/general_compiler.cpp | 12 ++- csrc/engine/compiler/general_compiler.hpp | 5 +- csrc/engine/infer_engine.cpp | 4 + csrc/engine/infer_engine.hpp | 2 + csrc/engine/rank_worker.cpp | 6 +- csrc/engine/rank_worker.hpp | 3 + csrc/pybind11/engine/engine.hpp | 6 ++ python/infinilm/base_config.py | 4 + python/infinilm/infer_engine.py | 2 + python/infinilm/llm/llm.py | 44 ++++++++++- python/infinilm/llm/request.py | 17 +++++ python/infinilm/llm/scheduler.py | 41 +++++++++- .../processors/basic_llm_processor.py | 40 +++++++--- python/infinilm/server/inference_server.py | 14 ++++ scripts/infer_task.py | 21 ++++++ scripts/launch_server.py | 75 +++++++++++++++---- 16 files changed, 266 insertions(+), 30 deletions(-) diff --git a/csrc/engine/compiler/general_compiler.cpp b/csrc/engine/compiler/general_compiler.cpp index 84ee670d..36c6420f 100644 --- a/csrc/engine/compiler/general_compiler.cpp +++ b/csrc/engine/compiler/general_compiler.cpp @@ -1,13 +1,18 @@ #include "general_compiler.hpp" namespace infinilm::engine { -GeneralCompiler::GeneralCompiler(const std::shared_ptr &model, RankBarrier *barrier) : GraphCompiler(model, barrier) { +GeneralCompiler::GeneralCompiler(const std::shared_ptr &model, RankBarrier *barrier, bool enable_chunk_prefill_graph) + : GraphCompiler(model, barrier), enable_chunk_prefill_graph_(enable_chunk_prefill_graph) { static_batching_compiler_ = std::make_unique(model_, barrier); + chunk_prefill_compiler_ = std::make_unique(model_, barrier); paged_compiler_ = std::make_unique(model_, barrier); } void GeneralCompiler::compile() { static_batching_compiler_->compile(); + if (enable_chunk_prefill_graph_) { + chunk_prefill_compiler_->compile(); + } paged_compiler_->compile(); } @@ -19,6 +24,11 @@ GeneralCompiler::Compiled GeneralCompiler::get_compiled(const InfinilmModel::Inp if (std::get<0>(result) != nullptr && std::get<1>(result) != nullptr) { return result; } + // chunk-prefill must be checked before decode (decode would also match if chunk_size==1) + result = chunk_prefill_compiler_.get()->get_compiled(input); + if (std::get<0>(result) != nullptr && std::get<1>(result) != nullptr) { + return result; + } result = paged_compiler_.get()->get_compiled(input); return result; } diff --git a/csrc/engine/compiler/general_compiler.hpp b/csrc/engine/compiler/general_compiler.hpp index e8b84b5d..3edbcea0 100644 --- a/csrc/engine/compiler/general_compiler.hpp +++ b/csrc/engine/compiler/general_compiler.hpp @@ -1,12 +1,13 @@ #pragma once +#include "chunk_prefill_compiler.hpp" #include "paged_compiler.hpp" #include "static_batching_compiler.hpp" namespace infinilm::engine { class GeneralCompiler : public GraphCompiler { public: - GeneralCompiler(const std::shared_ptr &model, RankBarrier *barrier); + GeneralCompiler(const std::shared_ptr &model, RankBarrier *barrier, bool enable_chunk_prefill_graph = false); void compile() override; @@ -15,5 +16,7 @@ class GeneralCompiler : public GraphCompiler { private: std::unique_ptr static_batching_compiler_; std::unique_ptr paged_compiler_; + std::unique_ptr chunk_prefill_compiler_; + bool enable_chunk_prefill_graph_; }; } // namespace infinilm::engine diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp index db0dfdd4..5b6ea143 100644 --- a/csrc/engine/infer_engine.cpp +++ b/csrc/engine/infer_engine.cpp @@ 
-25,6 +25,7 @@ InferEngine::InferEngine( infinicore::Device::Type device_type, const cache::CacheConfig *cache_config, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, backends::AttentionBackend attention_backend) // Changed parameter : communication_group_(distributed_config, device_type), legacy_model_config_(config), @@ -43,6 +44,7 @@ InferEngine::InferEngine( cache_config_ != nullptr ? cache_config_.get() : nullptr, barrier_.get(), enable_graph_compiling, + enable_chunk_prefill_graph, attention_backend_)); } @@ -56,6 +58,7 @@ InferEngine::InferEngine( infinicore::Device::Type device_type, const cache::CacheConfig *cache_config, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, backends::AttentionBackend attention_backend, std::optional kv_cache_dtype) // Changed parameter : communication_group_(distributed_config, device_type), attention_backend_(attention_backend) { @@ -82,6 +85,7 @@ InferEngine::InferEngine( cache_config_ != nullptr ? cache_config_.get() : nullptr, barrier_.get(), enable_graph_compiling, + enable_chunk_prefill_graph, attention_backend_)); } // Compile the model on all workers diff --git a/csrc/engine/infer_engine.hpp b/csrc/engine/infer_engine.hpp index e36ec369..153600c4 100644 --- a/csrc/engine/infer_engine.hpp +++ b/csrc/engine/infer_engine.hpp @@ -39,6 +39,7 @@ class InferEngine { infinicore::Device::Type device_type = infinicore::context::getDevice().getType(), const cache::CacheConfig *cache_config = nullptr, bool enable_graph_compiling = false, + bool enable_chunk_prefill_graph = false, backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); InferEngine( @@ -47,6 +48,7 @@ class InferEngine { infinicore::Device::Type device_type = infinicore::context::getDevice().getType(), const cache::CacheConfig *cache_config = nullptr, bool enable_graph_compiling = false, + bool enable_chunk_prefill_graph = false, backends::AttentionBackend attention_backend = backends::AttentionBackend::Default, std::optional kv_cache_dtype = std::nullopt); diff --git a/csrc/engine/rank_worker.cpp b/csrc/engine/rank_worker.cpp index 8a94c441..e607c569 100644 --- a/csrc/engine/rank_worker.cpp +++ b/csrc/engine/rank_worker.cpp @@ -27,11 +27,13 @@ RankWorker::RankWorker(const InfinilmModel::Config &model_config, const cache::CacheConfig *cache_config, RankBarrier *barrier, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, backends::AttentionBackend attention_backend) : legacy_model_config_(model_config), rank_info_(rank_info), attention_backend_(attention_backend), enable_graph_compiling_(enable_graph_compiling), + enable_chunk_prefill_graph_(enable_chunk_prefill_graph), job_cmd_(Command::INIT), has_job_(false), job_done_(false), @@ -56,12 +58,14 @@ RankWorker::RankWorker( const cache::CacheConfig *cache_config, RankBarrier *barrier, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, backends::AttentionBackend attention_backend) : infinilm_config_(infinilm_config), model_config_(infinilm_config->model_config), rank_info_(rank_info), attention_backend_(attention_backend), enable_graph_compiling_(enable_graph_compiling), + enable_chunk_prefill_graph_(enable_chunk_prefill_graph), job_cmd_(Command::INIT), has_job_(false), job_done_(false), @@ -303,7 +307,7 @@ void RankWorker::thread_loop() { throw std::runtime_error("Failed to create model"); } if (enable_graph_compiling_) { - compiler_ = std::make_unique(model_, barrier_); + compiler_ = std::make_unique(model_, barrier_, enable_chunk_prefill_graph_); } 
init_done_ = true; diff --git a/csrc/engine/rank_worker.hpp b/csrc/engine/rank_worker.hpp index f6adcf47..b045adf6 100644 --- a/csrc/engine/rank_worker.hpp +++ b/csrc/engine/rank_worker.hpp @@ -75,6 +75,7 @@ class RankWorker { const cache::CacheConfig *cache_config, RankBarrier *barrier, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, backends::AttentionBackend attention_backend); RankWorker(std::shared_ptr infinilm_config, @@ -82,6 +83,7 @@ class RankWorker { const cache::CacheConfig *cache_config, RankBarrier *barrier, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, backends::AttentionBackend attention_backend); // Submit a parameter load job and wait until the load completes on the worker thread. @@ -131,6 +133,7 @@ class RankWorker { // Graph Compiling bool enable_graph_compiling_; + bool enable_chunk_prefill_graph_; std::unique_ptr compiler_; // Command for the pending job (protected by mutex_) diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index 2741c9cd..a479f66b 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -37,6 +37,7 @@ inline void bind_infer_engine(py::module &m) { infinicore::Device::Type dev, std::shared_ptr cache_cfg, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, const std::string &attention_backend) { return std::make_shared( cfg, @@ -44,6 +45,7 @@ inline void bind_infer_engine(py::module &m) { dev, cache_cfg ? cache_cfg.get() : nullptr, enable_graph_compiling, + enable_chunk_prefill_graph, infinilm::backends::parse_attention_backend(attention_backend)); }), py::arg("config"), @@ -51,6 +53,7 @@ inline void bind_infer_engine(py::module &m) { py::arg("device_type") = infinicore::context::getDevice().getType(), py::arg("cache_config") = py::none(), py::arg("enable_graph_compiling") = false, + py::arg("enable_chunk_prefill_graph") = false, py::arg("attention_backend") = "default") .def("load_param", &InferEngine::load_param, py::arg("name"), py::arg("param"), @@ -81,6 +84,7 @@ inline void bind_infer_engine(py::module &m) { infinicore::Device::Type dev, std::shared_ptr cache_cfg, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, const std::string &attention_backend, std::optional kv_cache_dtype) { return std::make_shared( @@ -89,6 +93,7 @@ inline void bind_infer_engine(py::module &m) { dev, cache_cfg ? 
cache_cfg.get() : nullptr, enable_graph_compiling, + enable_chunk_prefill_graph, infinilm::backends::parse_attention_backend(attention_backend), kv_cache_dtype); }), @@ -97,6 +102,7 @@ inline void bind_infer_engine(py::module &m) { py::arg("device_type") = infinicore::context::getDevice().getType(), py::arg("cache_config") = py::none(), py::arg("enable_graph_compiling") = false, + py::arg("enable_chunk_prefill_graph") = false, py::arg("attention_backend") = "default", py::arg("kv_cache_dtype") = py::none()) .def("load_param", &InferEngine::load_param, diff --git a/python/infinilm/base_config.py b/python/infinilm/base_config.py index aab5dd45..d7c32568 100644 --- a/python/infinilm/base_config.py +++ b/python/infinilm/base_config.py @@ -61,6 +61,8 @@ def __init__(self): self.attn = self.args.attn self.enable_graph = self.args.enable_graph + self.enable_chunk_prefill_graph = self.args.enable_chunk_prefill_graph + self.chunk_size = self.args.chunk_size self.enable_paged_attn = self.args.enable_paged_attn self.num_blocks = self.args.num_blocks self.block_size = self.args.block_size @@ -122,6 +124,8 @@ def _add_common_args(self): choices=["default", "paged-attn", "flash-attn"], ) self.parser.add_argument("--enable-graph", action="store_true") + self.parser.add_argument("--enable-chunk-prefill-graph", action="store_true", help="enable chunk-prefill graph compiling") + self.parser.add_argument("--chunk-size", type=int, default=512, help="tokens per chunked-prefill slice (0 to disable)") self.parser.add_argument( "--enable-paged-attn", action="store_true", diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index 13bb18a1..2477bbc6 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -45,6 +45,7 @@ def __init__( distributed_config=DistConfig(1), cache_config=None, enable_graph_compiling=False, + enable_chunk_prefill_graph=False, attention_backend="default", kv_cache_dtype=None, ): @@ -60,6 +61,7 @@ def __init__( device._underlying.type, cache_config, enable_graph_compiling, + enable_chunk_prefill_graph, attention_backend, ( parse_dtype(kv_cache_dtype)._underlying diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index cba3af83..90de3edc 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -72,6 +72,8 @@ class EngineConfig: top_p: float = 0.8 top_k: int = 1 enable_graph: bool = False + enable_chunk_prefill_graph: bool = False + chunk_size: int = 0 attn_backend: str = "default" skip_load: bool = False @@ -91,6 +93,7 @@ def __init__(self, config: EngineConfig): device=self.device, distributed_config=DistConfig(config.tensor_parallel_size), enable_graph_compiling=config.enable_graph, + enable_chunk_prefill_graph=config.enable_chunk_prefill_graph, attention_backend=config.attn_backend, ) @@ -167,6 +170,8 @@ def _init_device(self): def add_request(self, request: InferenceRequest): """Add a request to the scheduler.""" + if self.cache_type == "paged" and self.config.chunk_size > 0: + request.chunk_size = self.config.chunk_size self.scheduler.add_request(request) def step(self) -> tuple[list[InferenceRequest], list[tuple]]: @@ -210,7 +215,18 @@ def _update_requests( sampled_tokens: List[int], ) -> List[tuple]: """Update request status after inference step.""" - if is_prefill: + # Detect a chunked-prefill mid-step: single request, prefill phase, + # and this chunk does not yet cover the whole prompt. 
In that case + # we must NOT consume a sampled token, NOT commit prefill blocks, + # and re-enqueue the request to keep chunking. + chunk_mid_step = ( + is_prefill + and len(requests) == 1 + and requests[0].is_chunking() + and not requests[0].chunk_is_last() + ) + + if is_prefill and not chunk_mid_step: match self.cache_type: case "paged": self.scheduler.cache_manager.reset_req_blocks() @@ -218,6 +234,20 @@ def _update_requests( self.scheduler.update_cache() case _: raise ValueError(f"Unsupported cache_type: {self.cache_type}") + + if chunk_mid_step: + req = requests[0] + req.chunk_prefill_offset += req.chunk_size + # If this request was aborted while chunking, drop it. + if req.is_aborted(): + logger.info( + f"Request {req.request_id} aborted by client during chunked-prefill" + ) + return [] + # Re-enqueue to keep producing chunks; no token sampled yet. + self.scheduler.requeue_chunking(req) + return [] + pending = [] for req, token_id in zip(requests, sampled_tokens): if req.is_aborted(): @@ -227,6 +257,10 @@ def _update_requests( continue if req.is_prefill: + # Clean up chunked-prefill state on the final chunk so the + # next forward pass on this request takes the decode path. + req.chunk_prefill_offset = 0 + req.chunk_size = 0 req.is_prefill = False req.generated_token_ids.append(token_id) @@ -361,6 +395,8 @@ def __init__( top_p: float = 0.8, top_k: int = 1, enable_graph: bool = False, + enable_chunk_prefill_graph: bool = False, + chunk_size: int = 0, attn_backend: str = "default", skip_load: bool = False, ): @@ -398,6 +434,8 @@ def __init__( top_p=top_p, top_k=top_k, enable_graph=enable_graph, + enable_chunk_prefill_graph=enable_chunk_prefill_graph, + chunk_size=chunk_size, attn_backend=attn_backend, skip_load=skip_load, ) @@ -539,6 +577,8 @@ def __init__( top_p: float = 0.8, top_k: int = 1, enable_graph: bool = False, + enable_chunk_prefill_graph: bool = False, + chunk_size: int = 0, attn_backend: str = "default", ): """Initialize AsyncLLMEngine. 
@@ -575,6 +615,8 @@ def __init__( top_p=top_p, top_k=top_k, enable_graph=enable_graph, + enable_chunk_prefill_graph=enable_chunk_prefill_graph, + chunk_size=chunk_size, attn_backend=attn_backend, ) self.engine = LLMEngine(config) diff --git a/python/infinilm/llm/request.py b/python/infinilm/llm/request.py index 15bcf69f..679b6e4d 100644 --- a/python/infinilm/llm/request.py +++ b/python/infinilm/llm/request.py @@ -144,6 +144,11 @@ def __init__( self.num_cached_tokens: int = 0 self.num_blocks: int = 0 + # Chunked-prefill state (0 = disabled, otherwise tokens per chunk) + self.chunk_size: int = 0 + # Number of prompt tokens already fed through forward as chunked-prefill + self.chunk_prefill_offset: int = 0 + # For server use self.request_data: Optional[dict] = request_data self.http_request: Optional[Any] = http_request @@ -186,6 +191,18 @@ def get_num_blocks_required(self, block_size: int) -> int: def get_max_tokens(self) -> Optional[int]: return self.sampling_params.max_tokens + def is_chunking(self) -> bool: + """Return True if this request is in the middle of chunked-prefill.""" + return ( + self.chunk_size > 0 + and self.is_prefill + and self.prompt_length > self.chunk_size + ) + + def chunk_is_last(self) -> bool: + """Return True if the next chunk would finish the prompt.""" + return self.chunk_prefill_offset + self.chunk_size >= self.prompt_length + def is_finished(self) -> bool: return self.status in [ RequestStatus.FINISHED, diff --git a/python/infinilm/llm/scheduler.py b/python/infinilm/llm/scheduler.py index f9c11635..95a84480 100644 --- a/python/infinilm/llm/scheduler.py +++ b/python/infinilm/llm/scheduler.py @@ -42,6 +42,9 @@ def __init__( ): self.waiting_queue = janus.Queue() self.running_queue = janus.Queue() + # Requests in the middle of chunked-prefill — scheduled at high priority, + # single-request batches only (to match the C++ ChunkPrefillCompiler graph signature). + self.chunking_queue = janus.Queue() self.max_batch_size = max_batch_size self.cache_manager = BlockManager(num_blocks=num_blocks, block_size=block_size) @@ -53,7 +56,27 @@ def add_request(self, request: InferenceRequest): self.waiting_queue.sync_q.put(request) def schedule(self) -> Optional[SchedulerOutput]: - """Schedule and return batch of requests to execute.""" + """Schedule and return batch of requests to execute. + + Priority (mirrors launch_server.py chunked-prefill scheduling): + 1. Running queue (decode) — short / latency-sensitive + 2. Chunking queue (in-flight chunked-prefill) — single-request slice + 3. Waiting queue (new prefill) — may start chunking if prompt is long + """ + # 2) Continue an in-flight chunked-prefill request (single-request batch). + try: + req = self.chunking_queue.sync_q.get_nowait() + except queue.Empty: + req = None + if req is not None: + if req.is_finished(): + self.complete_requests([req]) + else: + return SchedulerOutput( + scheduled_requests=[req], + is_prefill=True, + ) + scheduled_requests = [] is_prefill = False @@ -91,6 +114,18 @@ def schedule(self) -> Optional[SchedulerOutput]: req.num_blocks = len(req.block_table) req.status = RequestStatus.RUNNING + + # Start chunked-prefill: enqueue into chunking_queue and emit a + # single-request batch immediately. We don't mix chunked-prefill + # with other requests in the same batch — the C++ ChunkPrefillCompiler + # graph is keyed on (batch_size, chunk_size). 
+ if req.chunk_size > 0 and req.prompt_length > req.chunk_size: + req.chunk_prefill_offset = 0 + return SchedulerOutput( + scheduled_requests=[req], + is_prefill=True, + ) + scheduled_requests.append(req) # Return prefill batch if any waiting requests were scheduled @@ -135,6 +170,10 @@ def schedule(self) -> Optional[SchedulerOutput]: return None + def requeue_chunking(self, req: InferenceRequest): + """Put a request back into the chunking queue after a chunk has run.""" + self.chunking_queue.sync_q.put(req) + def complete_requests(self, requests: List[InferenceRequest]): """Handle completed requests and free their blocks.""" for req in requests: diff --git a/python/infinilm/processors/basic_llm_processor.py b/python/infinilm/processors/basic_llm_processor.py index 070a4062..f5e603ba 100644 --- a/python/infinilm/processors/basic_llm_processor.py +++ b/python/infinilm/processors/basic_llm_processor.py @@ -185,19 +185,39 @@ def _build_model_input_from_batch_scheduler_output( if scheduler_output.is_prefill: # Prefill phase req_tokens = req.get_input_tokens() - tokens_to_compute = req_tokens[num_cached:] - tokens.extend(tokens_to_compute) - compute_len = len(tokens_to_compute) - seq_len = len(req_tokens) - seq_lens.append(seq_len) + # Chunked-prefill: only feed [chunk_prefill_offset : +chunk_size). + # past_kv_lengths = chunk_prefill_offset (attention sees the prefix + # already committed); total_kv_lengths = chunk_prefill_offset + + # len(tokens_to_compute). This keeps batch_size=1 and total_tokens + # == chunk_size so the C++ ChunkPrefillCompiler graph hits. + if req.is_chunking(): + start = req.chunk_prefill_offset + end = min(start + req.chunk_size, len(req_tokens)) + tokens_to_compute = req_tokens[start:end] + compute_len = len(tokens_to_compute) + tokens.extend(tokens_to_compute) + seq_len = end # attention prefix length after this chunk + seq_lens.append(seq_len) + current_offset += compute_len + seq_offsets.append(current_offset) + slot_mapping.extend(req.slot_mapping[start:end]) + cached_lens.append(start) + position_ids.extend(range(start, end)) + else: + tokens_to_compute = req_tokens[num_cached:] + tokens.extend(tokens_to_compute) - current_offset += compute_len - seq_offsets.append(current_offset) + compute_len = len(tokens_to_compute) + seq_len = len(req_tokens) + seq_lens.append(seq_len) - slot_mapping.extend(req.slot_mapping) - cached_lens.append(num_cached) - position_ids.extend(range(num_cached, num_cached + compute_len)) + current_offset += compute_len + seq_offsets.append(current_offset) + + slot_mapping.extend(req.slot_mapping) + cached_lens.append(num_cached) + position_ids.extend(range(num_cached, num_cached + compute_len)) else: # Decode phase diff --git a/python/infinilm/server/inference_server.py b/python/infinilm/server/inference_server.py index 71e9c992..ac7e94e7 100644 --- a/python/infinilm/server/inference_server.py +++ b/python/infinilm/server/inference_server.py @@ -108,6 +108,8 @@ def __init__( host: str = "0.0.0.0", port: int = 8000, enable_graph: bool = False, + enable_chunk_prefill_graph: bool = False, + chunk_size: int = 0, attn_backend: str = "default", ignore_eos: bool = False, ): @@ -130,6 +132,10 @@ def __init__( host: Server host address. port: Server port number. enable_graph: Whether to enable graph compiling. + enable_chunk_prefill_graph: Whether to enable chunk-prefill graph compiling. + chunk_size: Tokens per chunked-prefill slice (0 = disabled). 
When > 0 and paged + cache is used, long prompts are sliced and each slice goes through forward + separately so the C++ ChunkPrefillCompiler precompiled graph can be reused. attn_backend: Attention backend to use ('default', 'flash-attn'). """ self.model_path = model_path @@ -150,6 +156,8 @@ def __init__( self.host = host self.port = port self.enable_graph = enable_graph + self.enable_chunk_prefill_graph = enable_chunk_prefill_graph + self.chunk_size = chunk_size self.attn_backend = attn_backend self.ignore_eos = ignore_eos @@ -182,11 +190,15 @@ async def lifespan(app: FastAPI): top_p=self.top_p, top_k=self.top_k, enable_graph=self.enable_graph, + enable_chunk_prefill_graph=self.enable_chunk_prefill_graph, + chunk_size=self.chunk_size, attn_backend=self.attn_backend, ) self.engine.start() logger.info(f"Engine initialized with model at {self.model_path}") logger.info(f" enable_graph: {self.enable_graph}") + logger.info(f" enable_chunk_prefill_graph: {self.enable_chunk_prefill_graph}") + logger.info(f" chunk_size: {self.chunk_size}") yield self.engine.stop() @@ -572,6 +584,8 @@ def main(): host=cfg.host, port=cfg.port, enable_graph=cfg.enable_graph, + enable_chunk_prefill_graph=cfg.enable_chunk_prefill_graph, + chunk_size=cfg.chunk_size, attn_backend=cfg.attn, ignore_eos=cfg.ignore_eos, ) diff --git a/scripts/infer_task.py b/scripts/infer_task.py index 0d1231b7..1851f0a0 100644 --- a/scripts/infer_task.py +++ b/scripts/infer_task.py @@ -10,6 +10,8 @@ def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens): self.end_tokens = end_tokens self._kv_cache = None self.pos = 0 + self._discard_output = False + self._remaining_tokens = None def bind_kvcache(self, kv_cache, pos=0): self._kv_cache = kv_cache @@ -24,6 +26,25 @@ def release_kvcache(self): def kvcache(self): return self._kv_cache + def setup_chunked_prefill(self, chunk_size): + if chunk_size <= 0 or len(self.tokens) <= chunk_size: + return + self._remaining_tokens = self.tokens[chunk_size:] + self.tokens = self.tokens[:chunk_size] + self._discard_output = True + + def advance_prefill_chunk(self, chunk_size): + self._kv_cache.update_tokens(self.tokens, self.pos) + self.pos += len(self.tokens) + + if len(self._remaining_tokens) <= chunk_size: + self.tokens = self._remaining_tokens + self._remaining_tokens = None + self._discard_output = False + else: + self.tokens = self._remaining_tokens[:chunk_size] + self._remaining_tokens = self._remaining_tokens[chunk_size:] + def next(self, out_token): self._kv_cache.update_tokens(self.tokens, self.pos) diff --git a/scripts/launch_server.py b/scripts/launch_server.py index d04d4f69..0639a28b 100644 --- a/scripts/launch_server.py +++ b/scripts/launch_server.py @@ -64,6 +64,13 @@ def parse_args(): default=None, help="Max token sequence length that model will handle (follows model config if not provided)", ) + parser.add_argument( + "--chunk-size", + type=int, + default=512, + help="Maximum number of tokens per prefill chunk (default: 512). " + "Set to 0 to disable chunked prefill.", + ) parser.add_argument( "--awq", action="store_true", @@ -86,8 +93,10 @@ def parse_args(): USE_AWQ = args.awq USE_GPTQ = args.gptq MAX_BATCH = args.max_batch +CHUNK_SIZE = args.chunk_size print( - f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs." + f"Using MAX_BATCH={MAX_BATCH}, CHUNK_SIZE={CHUNK_SIZE}. " + f"Try reduce these values if out of memory error occurs." 
) @@ -163,32 +172,66 @@ async def lifespan(app: FastAPI): # App loop: take requests from the queue, do inference, and put unfinished requests back into the queue. +# Uses priority scheduling: decode/short tasks first, then prefill chunks. def worker_loop(app): + pending_prefill = [] # Low priority: chunked prefill tasks + while True: + # Drain all available tasks from the queue + incoming = [] try: task = app.state.request_queue.sync_q.get(timeout=0.01) + if task is None: + return + incoming.append(task) except queue.Empty: - continue - - if task is None: - return + pass - batch = [task] - while len(batch) < MAX_BATCH: + while True: try: - req = app.state.request_queue.sync_q.get_nowait() - if req is not None: - batch.append(req) + task = app.state.request_queue.sync_q.get_nowait() + if task is None: + return + incoming.append(task) except queue.Empty: break + + # Separate into high priority (decode/new short) and low priority (prefill chunks) + high_priority = [] + for t in incoming: + if t._discard_output: + pending_prefill.append(t) + else: + high_priority.append(t) + + # Build batch: high priority first, then fill with prefill chunks + batch = [] + while high_priority and len(batch) < MAX_BATCH: + batch.append(high_priority.pop(0)) + while pending_prefill and len(batch) < MAX_BATCH: + batch.append(pending_prefill.pop(0)) + + if not batch: + continue + output_tokens = app.state.model.batch_infer_one_round(batch) for task, token in zip(batch, output_tokens): - task.output(token) - if task.finish_reason is None: - app.state.request_queue.sync_q.put(task) + if task._discard_output: + task.advance_prefill_chunk(CHUNK_SIZE) + if task.finish_reason is None: + if task._discard_output: + pending_prefill.append(task) + else: + app.state.request_queue.sync_q.put(task) + else: + app.state.kv_cache_pool.release_sync(task) else: - print(f"[INFO] Task {task.id} finished infer.") - app.state.kv_cache_pool.release_sync(task) + task.output(token) + if task.finish_reason is None: + app.state.request_queue.sync_q.put(task) + else: + print(f"[INFO] Task {task.id} finished infer.") + app.state.kv_cache_pool.release_sync(task) def build_task(id_, request_data, request: Request): @@ -214,6 +257,7 @@ async def chat_stream(id_, request_data, request: Request): try: infer_task = build_task(id_, request_data, request) await request.app.state.kv_cache_pool.acquire(infer_task) + infer_task.setup_chunked_prefill(CHUNK_SIZE) # Initial empty content chunk = json.dumps( @@ -255,6 +299,7 @@ async def chat(id_, request_data, request: Request): try: infer_task = build_task(id_, request_data, request) await request.app.state.kv_cache_pool.acquire(infer_task) + infer_task.setup_chunked_prefill(CHUNK_SIZE) request.app.state.request_queue.sync_q.put(infer_task) output = [] while True: From bb68ca563604f079cfd35d3f5509b8e56ac654bf Mon Sep 17 00:00:00 2001 From: huidesheng <1832140001@qq.com> Date: Wed, 13 May 2026 11:15:15 +0800 Subject: [PATCH 2/2] add chunk_prefill_compiler.cpp/.hpp --- .../compiler/chunk_prefill_compiler.cpp | 186 ++++++++++++++++++ .../compiler/chunk_prefill_compiler.hpp | 42 ++++ 2 files changed, 228 insertions(+) create mode 100644 csrc/engine/compiler/chunk_prefill_compiler.cpp create mode 100644 csrc/engine/compiler/chunk_prefill_compiler.hpp diff --git a/csrc/engine/compiler/chunk_prefill_compiler.cpp b/csrc/engine/compiler/chunk_prefill_compiler.cpp new file mode 100644 index 00000000..266bd0e7 --- /dev/null +++ b/csrc/engine/compiler/chunk_prefill_compiler.cpp @@ -0,0 +1,186 @@ +#include 
"chunk_prefill_compiler.hpp" +#include "infinicore/context/context.hpp" + + +namespace { +inline void set_zeros(infinicore::Tensor &tensor) { + std::vector zeros(tensor->nbytes(), 0); + infinicore::context::memcpyH2D(tensor->data(), zeros.data(), tensor->nbytes(), false); +} +} // namespace + +namespace infinilm::engine { + +ChunkPrefillCompiler::ChunkPrefillCompiler(const std::shared_ptr &model, RankBarrier *barrier) + : GraphCompiler(model, barrier) { + // Enumerate chunk sizes for chunk-prefill + for (size_t cs : {64, 128, 256, 512, 1024, 2048}) { + chunk_sizes_.push_back(cs); + } + // Enumerate batch sizes for prefill (typically smaller than decode) + for (size_t b = 1; b < 32; b++) { + prefill_batch_sizes_.push_back(b); + } + for (size_t b = 32; b < 64; b += 8) { + prefill_batch_sizes_.push_back(b); + } + for (size_t b = 64; b < 128; b += 16) { + prefill_batch_sizes_.push_back(b); + } + for (size_t b = 128; b < 256; b += 32) { + prefill_batch_sizes_.push_back(b); + } + for (size_t b = 256; b <= 512; b += 64) { + prefill_batch_sizes_.push_back(b); + } +} + +void ChunkPrefillCompiler::compile() { + if (model_->get_cache_config() != nullptr && + dynamic_cast(model_->get_cache_config())) { + + const auto *paged_config = + dynamic_cast(model_->get_cache_config()); + size_t nblocks = paged_config->num_blocks(); + + compiled_map_prefill_.clear(); + + // Max total tokens to avoid OOM during graph recording + constexpr size_t MAX_TOTAL_TOKENS = 4096; + + // Pre-allocate a shared block_tables_holder for the largest (batch_size) we'll use + size_t max_batch = *std::max_element(prefill_batch_sizes_.begin(), prefill_batch_sizes_.end()); + size_t block_per_req = nblocks / max_batch; + block_tables_holder_ = infinicore::Tensor::empty( + {nblocks}, infinicore::DataType::I32, infinicore::context::getDevice()); + set_zeros(block_tables_holder_); + + for (size_t b : prefill_batch_sizes_) { + for (size_t cs : chunk_sizes_) { + size_t total_tokens = b * cs; + if (total_tokens > MAX_TOTAL_TOKENS) { + continue; + } + + size_t bpr = nblocks / b; // block_per_req for this batch size + + InfinilmModel::Input input; + + // input_ids: [1, total_tokens] — all tokens for this batch packed together + input.input_ids = infinicore::Tensor::empty( + {1, total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice()); + set_zeros(input.input_ids.value()); + + // position_ids: [total_tokens] + input.position_ids = infinicore::Tensor::empty( + {total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice()); + set_zeros(input.position_ids.value()); + + // total_sequence_lengths: [b], set to cs (first-chunk scenario) + input.total_sequence_lengths = infinicore::Tensor::empty( + {b}, infinicore::DataType::I32, infinicore::context::getDevice()); + { + std::vector tsl(b, static_cast(cs)); + infinicore::context::memcpyH2D( + input.total_sequence_lengths.value()->data(), + tsl.data(), b * sizeof(int32_t), false); + } + + // input_offsets: [b+1], stride = cs + input.input_offsets = infinicore::Tensor::empty( + {b + 1}, infinicore::DataType::I32, infinicore::context::getDevice()); + { + std::vector offsets(b + 1); + for (size_t i = 0; i <= b; i++) { + offsets[i] = static_cast(i * cs); + } + infinicore::context::memcpyH2D( + input.input_offsets.value()->data(), + offsets.data(), (b + 1) * sizeof(int32_t), false); + } + + // cu_seqlens: [b+1], same layout as input_offsets for prefill + input.cu_seqlens = infinicore::Tensor::empty( + {b + 1}, infinicore::DataType::I32, infinicore::context::getDevice()); + { 
+ std::vector cu(b + 1); + for (size_t i = 0; i <= b; i++) { + cu[i] = static_cast(i * cs); + } + infinicore::context::memcpyH2D( + input.cu_seqlens.value()->data(), + cu.data(), (b + 1) * sizeof(int32_t), false); + } + + // block_tables: view into the shared holder [b, bpr] + input.block_tables = block_tables_holder_->as_strided( + {b, bpr}, {(ptrdiff_t)bpr, 1}); + + // slot_mapping: [total_tokens] + input.slot_mapping = infinicore::Tensor::empty( + {total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice()); + set_zeros(input.slot_mapping.value()); + + barrier_->wait(); + infinicore::context::startGraphRecording(); + auto output = model_->forward(input); + auto graph = infinicore::context::stopGraphRecording(); + barrier_->wait(); + + auto shared_output = std::shared_ptr( + new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)}); + + compiled_map_prefill_[std::make_tuple(b, cs)] = + CompiledResult{std::move(input), std::make_tuple(graph, shared_output)}; + } + } + } +} + +ChunkPrefillCompiler::Compiled ChunkPrefillCompiler::get_compiled(const InfinilmModel::Input &input) { + if (model_->get_cache_config() == nullptr || + !dynamic_cast(model_->get_cache_config())) { + return {nullptr, nullptr}; + } + + if (!input.block_tables.has_value() || !input.input_ids.has_value()) { + return {nullptr, nullptr}; + } + + size_t batch_size = input.block_tables.value()->size(0); + size_t block_per_req = input.block_tables.value()->size(1); + size_t total_tokens = input.input_ids.value()->size(1); + + // Prefill: total_tokens is a multiple of batch_size, and chunk_size > 1 + if (total_tokens == 0 || total_tokens % batch_size != 0) { + return {nullptr, nullptr}; + } + size_t chunk_size = total_tokens / batch_size; + if (chunk_size <= 1) { + // Single-token case belongs to decode + return {nullptr, nullptr}; + } + + auto result = compiled_map_prefill_.find(std::make_tuple(batch_size, chunk_size)); + if (result == compiled_map_prefill_.end()) { + return {nullptr, nullptr}; + } + + auto &graph_input = result->second.input; + + graph_input.input_ids.value()->copy_from(input.input_ids.value()); + graph_input.position_ids.value()->copy_from(input.position_ids.value()); + graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value()); + graph_input.input_offsets.value()->copy_from(input.input_offsets.value()); + graph_input.cu_seqlens.value()->copy_from(input.cu_seqlens.value()); + graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value()); + graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value()); + + auto graph = std::get<0>(result->second.compiled); + auto shared_output = std::shared_ptr( + new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()}); + + return std::make_tuple(graph, shared_output); +} + +} // namespace infinilm::engine diff --git a/csrc/engine/compiler/chunk_prefill_compiler.hpp b/csrc/engine/compiler/chunk_prefill_compiler.hpp new file mode 100644 index 00000000..bd701158 --- /dev/null +++ b/csrc/engine/compiler/chunk_prefill_compiler.hpp @@ -0,0 +1,42 @@ +#pragma once + +#include "graph_compiler.hpp" + +#include + +namespace infinilm::engine { +class ChunkPrefillCompiler : public GraphCompiler { +public: + ChunkPrefillCompiler(const std::shared_ptr &model, RankBarrier *barrier); + + void compile() override; + + Compiled get_compiled(const InfinilmModel::Input &input) override; + +private: + struct TupleHash { + size_t 
operator()(const std::tuple<size_t, size_t> &t) const noexcept {
+            auto h1 = std::hash<size_t>{}(std::get<0>(t));
+            auto h2 = std::hash<size_t>{}(std::get<1>(t));
+            return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2));
+        }
+    };
+
+    std::vector<size_t> chunk_sizes_;
+    std::vector<size_t> prefill_batch_sizes_;
+
+    infinicore::Tensor block_tables_holder_;
+
+    struct CompiledResult {
+        InfinilmModel::Input input;
+        Compiled compiled;
+    };
+
+    // Key: (batch_size, chunk_size)
+    std::unordered_map<
+        std::tuple<size_t, size_t>,
+        CompiledResult,
+        TupleHash>
+        compiled_map_prefill_;
+};
+} // namespace infinilm::engine
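
A minimal sketch of the chunk walk driven by the new request.py / llm.py state, assuming a hypothetical helper chunk_spans; chunk_size, chunk_prefill_offset, prompt_length and the predicates is_chunking() / chunk_is_last() mirror the fields added in request.py, and are not the patch's actual API.

def chunk_spans(prompt_length: int, chunk_size: int):
    """Yield (start, end, is_last) prompt slices, one per scheduler step."""
    # chunk_size == 0 disables chunking; short prompts prefill in one shot
    if chunk_size <= 0 or prompt_length <= chunk_size:   # is_chunking() == False
        yield 0, prompt_length, True
        return
    offset = 0                                            # chunk_prefill_offset
    while True:
        end = min(offset + chunk_size, prompt_length)
        is_last = offset + chunk_size >= prompt_length    # chunk_is_last()
        yield offset, end, is_last
        if is_last:
            return
        offset += chunk_size                              # mid-step advance in _update_requests()

for start, end, last in chunk_spans(prompt_length=1300, chunk_size=512):
    print(start, end, last)
# 0 512 False      <- logits discarded, request re-enqueued via requeue_chunking()
# 512 1024 False   <- logits discarded
# 1024 1300 True   <- final slice: first token is sampled here

Only the full 512-token slices can hit a ChunkPrefillCompiler graph: get_compiled keys on (batch_size, total_tokens / batch_size), so the ragged final slice (276 tokens above) has no matching entry and GeneralCompiler::get_compiled falls through to the paged compiler lookup for that step.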