diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index 4aeec8af..d84d223f 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -67,7 +67,11 @@ inline void bind_infer_engine(py::module &m) { return state_dict_tp_all; }) .def( - "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") + "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { + py::gil_scoped_release release; + return self.forward(input); + }, + "Run inference on all ranks with arbitrary arguments") .def( "reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) .def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr { @@ -113,7 +117,11 @@ inline void bind_infer_engine(py::module &m) { return state_dict_tp_all; }) .def( - "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") + "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { + py::gil_scoped_release release; + return self.forward(input); + }, + "Run inference on all ranks with arbitrary arguments") .def( "reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) .def("get_cache_config", [](const InferEngine &self) { diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index 354c3902..7b6ceea4 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -6,6 +6,7 @@ - AsyncLLM class for asynchronous streaming (server use) """ +import asyncio import time import uuid import logging @@ -189,16 +190,18 @@ def add_request(self, request: InferenceRequest): """Add a request to the scheduler.""" self.scheduler.add_request(request) - def step(self) -> List[InferenceRequest]: + def step(self) -> tuple[list[InferenceRequest], list[tuple]]: """Run one inference step. Returns: - List of requests that were processed in this step. + A tuple of: + - scheduled_requests: Requests that were scheduled and processed in this step. + - pending: Pending streaming outputs as (async_queue, TokenOutput) pairs. """ # Schedule requests scheduler_output = self.scheduler.schedule() if scheduler_output is None or not scheduler_output.scheduled_requests: - return [] + return [], [] # Build model inputs model_input_dict = scheduler_output.build_model_inputs( @@ -211,13 +214,13 @@ def step(self) -> List[InferenceRequest]: sampled_tokens_list = sampled_tokens.to_numpy().tolist() # Update request status - self._update_requests( + pending = self._update_requests( scheduler_output.is_prefill, scheduler_output.scheduled_requests, sampled_tokens_list, ) - return scheduler_output.scheduled_requests + return scheduler_output.scheduled_requests, pending def _prepare_model_input(self, model_input_dict: dict) -> dict: """Convert model input dict to infinicore tensors.""" @@ -246,7 +249,7 @@ def _update_requests( is_prefill: bool, requests: List[InferenceRequest], sampled_tokens: List[int], - ): + ) -> List[tuple]: """Update request status after inference step.""" if is_prefill: match self.cache_type: @@ -256,9 +259,8 @@ def _update_requests( self.scheduler.update_cache() case _: raise ValueError(f"Unsupported cache_type: {self.cache_type}") - + pending = [] for req, token_id in zip(requests, sampled_tokens): - if req.is_aborted(): logger.info( f"Request {req.request_id} aborted by client, skipping update" @@ -320,16 +322,10 @@ def _update_requests( f"Request {req.request_id} aborted before putting token" ) continue - try: - req.output_queue.sync_q.put(output) - except Exception as e: - logger.warning( - f"Failed to put token for {req.request_id}: {e}. " - f"Likely due to client disconnecting or request cancelation." - ) - continue + pending.append((req.output_queue.async_q, output)) self.scheduler.complete_requests(requests) + return pending def _check_request_finished(self, req: InferenceRequest, token_id: int) -> bool: """Check if request generation is finished.""" @@ -597,6 +593,7 @@ def __init__( self._running = False self._step_thread: Optional[threading.Thread] = None + self._loop: Optional[asyncio.AbstractEventLoop] = None self._healthy = True def is_healthy(self) -> bool: @@ -608,6 +605,7 @@ def start(self): logger.warning("AsyncLLMEngine is already running") return + self._loop = asyncio.get_running_loop() self._running = True self._step_thread = threading.Thread( target=self._step_loop, daemon=True, name="AsyncLLMEngineStepThread" @@ -630,15 +628,28 @@ def _step_loop(self): """Background loop that runs inference steps.""" while self._running: try: - requests = self.engine.step() + requests, pending = self.engine.step() if not requests: time.sleep(0.01) + elif pending: + self._loop.call_soon_threadsafe(self._batch_put, pending) except Exception as e: logger.error(f"Error in step loop: {e}", exc_info=True) self._healthy = False self._running = False break + @staticmethod + def _batch_put(pending): + for async_q, output in pending: + try: + async_q.put_nowait(output) + except Exception as e: + logger.warning( + f"Failed to put token for request {output.request_id}: {e}. " + f"Likely due to client disconnecting or request cancelation." + ) + def add_request( self, prompt: Optional[str] = None, diff --git a/python/infinilm/llm/scheduler.py b/python/infinilm/llm/scheduler.py index 33ea67d0..92691416 100644 --- a/python/infinilm/llm/scheduler.py +++ b/python/infinilm/llm/scheduler.py @@ -59,6 +59,7 @@ def build_model_inputs( slot_mapping = [] cached_lens = [] position_ids = [] + cu_seqlens = [0] max_block_table_len = max( len(req.block_table) for req in self.scheduled_requests @@ -73,37 +74,37 @@ def build_model_inputs( tokens_to_compute = req_tokens[num_cached:] tokens.extend(tokens_to_compute) - seq_len = len(tokens_to_compute) - seq_lens.append(len(req_tokens)) + compute_len = len(tokens_to_compute) + seq_len = len(req_tokens) + seq_lens.append(seq_len) - current_offset += seq_len + current_offset += compute_len seq_offsets.append(current_offset) slot_mapping.extend(req.slot_mapping) cached_lens.append(num_cached) - position_ids.extend(range(num_cached, num_cached + seq_len)) + position_ids.extend(range(num_cached, num_cached + compute_len)) else: # Decode phase + seq_len = req.get_total_length() last_token = req.generated_token_ids[-1] tokens.append(last_token) - seq_lens.append(req.get_total_length()) + seq_lens.append(seq_len) current_offset += 1 seq_offsets.append(current_offset) slot_mapping.extend(req.slot_mapping) cached_lens.append(num_cached) - position_ids.append(req.get_total_length() - 1) + position_ids.append(seq_len - 1) # Pad block_table to same length padded_block_table = req.block_table + [-1] * ( max_block_table_len - len(req.block_table) ) block_tables.append(padded_block_table) - cu_seqlens = [0] - for l in seq_lens: - cu_seqlens.append(cu_seqlens[-1] + l) + cu_seqlens.append(cu_seqlens[-1] + seq_len) return { "input_ids": [tokens],