Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions csrc/pybind11/engine/engine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ inline void bind_infer_engine(py::module &m) {
return state_dict_tp_all;
})
.def(
"forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
"forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output {
py::gil_scoped_release release;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

py::gil_scoped_release release; 这是什么

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

释放Python的GIL

return self.forward(input);
},
"Run inference on all ranks with arbitrary arguments")
.def(
"reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
.def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr<cache::CacheConfig> {
Expand Down Expand Up @@ -113,7 +117,11 @@ inline void bind_infer_engine(py::module &m) {
return state_dict_tp_all;
})
.def(
"forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
"forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output {
py::gil_scoped_release release;
return self.forward(input);
},
"Run inference on all ranks with arbitrary arguments")
.def(
"reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
.def("get_cache_config", [](const InferEngine &self) {
Expand Down
45 changes: 28 additions & 17 deletions python/infinilm/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
- AsyncLLM class for asynchronous streaming (server use)
"""

import asyncio
import time
import uuid
import logging
Expand Down Expand Up @@ -189,16 +190,18 @@ def add_request(self, request: InferenceRequest):
"""Add a request to the scheduler."""
self.scheduler.add_request(request)

def step(self) -> List[InferenceRequest]:
def step(self) -> tuple[list[InferenceRequest], list[tuple]]:
"""Run one inference step.

Returns:
List of requests that were processed in this step.
A tuple of:
- scheduled_requests: Requests that were scheduled and processed in this step.
- pending: Pending streaming outputs as (async_queue, TokenOutput) pairs.
"""
# Schedule requests
scheduler_output = self.scheduler.schedule()
if scheduler_output is None or not scheduler_output.scheduled_requests:
return []
return [], []

# Build model inputs
model_input_dict = scheduler_output.build_model_inputs(
Expand All @@ -211,13 +214,13 @@ def step(self) -> List[InferenceRequest]:
sampled_tokens_list = sampled_tokens.to_numpy().tolist()

# Update request status
self._update_requests(
pending = self._update_requests(
scheduler_output.is_prefill,
scheduler_output.scheduled_requests,
sampled_tokens_list,
)

return scheduler_output.scheduled_requests
return scheduler_output.scheduled_requests, pending

def _prepare_model_input(self, model_input_dict: dict) -> dict:
"""Convert model input dict to infinicore tensors."""
Expand Down Expand Up @@ -246,7 +249,7 @@ def _update_requests(
is_prefill: bool,
requests: List[InferenceRequest],
sampled_tokens: List[int],
):
) -> List[tuple]:
"""Update request status after inference step."""
if is_prefill:
match self.cache_type:
Expand All @@ -256,9 +259,8 @@ def _update_requests(
self.scheduler.update_cache()
case _:
raise ValueError(f"Unsupported cache_type: {self.cache_type}")

pending = []
for req, token_id in zip(requests, sampled_tokens):

if req.is_aborted():
logger.info(
f"Request {req.request_id} aborted by client, skipping update"
Expand Down Expand Up @@ -320,16 +322,10 @@ def _update_requests(
f"Request {req.request_id} aborted before putting token"
)
continue
try:
req.output_queue.sync_q.put(output)
except Exception as e:
logger.warning(
f"Failed to put token for {req.request_id}: {e}. "
f"Likely due to client disconnecting or request cancelation."
)
continue
pending.append((req.output_queue.async_q, output))

self.scheduler.complete_requests(requests)
return pending

def _check_request_finished(self, req: InferenceRequest, token_id: int) -> bool:
"""Check if request generation is finished."""
Expand Down Expand Up @@ -597,6 +593,7 @@ def __init__(

self._running = False
self._step_thread: Optional[threading.Thread] = None
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._healthy = True

def is_healthy(self) -> bool:
Expand All @@ -608,6 +605,7 @@ def start(self):
logger.warning("AsyncLLMEngine is already running")
return

self._loop = asyncio.get_running_loop()
self._running = True
self._step_thread = threading.Thread(
target=self._step_loop, daemon=True, name="AsyncLLMEngineStepThread"
Expand All @@ -630,15 +628,28 @@ def _step_loop(self):
"""Background loop that runs inference steps."""
while self._running:
try:
requests = self.engine.step()
requests, pending = self.engine.step()
if not requests:
time.sleep(0.01)
elif pending:
self._loop.call_soon_threadsafe(self._batch_put, pending)
except Exception as e:
logger.error(f"Error in step loop: {e}", exc_info=True)
self._healthy = False
self._running = False
break

@staticmethod
def _batch_put(pending):
    """Deliver a batch of pending token outputs to their per-request queues.

    Intended to run on the event loop thread (scheduled from the step
    thread via ``call_soon_threadsafe``). Each item in ``pending`` is a
    ``(async_queue, token_output)`` pair; a failed put is logged and
    skipped so a single disconnected client cannot stall the batch.
    """
    for queue, token in pending:
        try:
            queue.put_nowait(token)
        except Exception as err:
            # Best-effort delivery: swallow per-item failures after logging.
            logger.warning(
                f"Failed to put token for request {token.request_id}: {err}. "
                f"Likely due to client disconnecting or request cancelation."
            )

def add_request(
self,
prompt: Optional[str] = None,
Expand Down
19 changes: 10 additions & 9 deletions python/infinilm/llm/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def build_model_inputs(
slot_mapping = []
cached_lens = []
position_ids = []
cu_seqlens = [0]

max_block_table_len = max(
len(req.block_table) for req in self.scheduled_requests
Expand All @@ -73,37 +74,37 @@ def build_model_inputs(
tokens_to_compute = req_tokens[num_cached:]
tokens.extend(tokens_to_compute)

seq_len = len(tokens_to_compute)
seq_lens.append(len(req_tokens))
compute_len = len(tokens_to_compute)
seq_len = len(req_tokens)
seq_lens.append(seq_len)

current_offset += seq_len
current_offset += compute_len
seq_offsets.append(current_offset)

slot_mapping.extend(req.slot_mapping)
cached_lens.append(num_cached)
position_ids.extend(range(num_cached, num_cached + seq_len))
position_ids.extend(range(num_cached, num_cached + compute_len))

else:
# Decode phase
seq_len = req.get_total_length()
last_token = req.generated_token_ids[-1]
tokens.append(last_token)
seq_lens.append(req.get_total_length())
seq_lens.append(seq_len)

current_offset += 1
seq_offsets.append(current_offset)

slot_mapping.extend(req.slot_mapping)
cached_lens.append(num_cached)
position_ids.append(req.get_total_length() - 1)
position_ids.append(seq_len - 1)

# Pad block_table to same length
padded_block_table = req.block_table + [-1] * (
max_block_table_len - len(req.block_table)
)
block_tables.append(padded_block_table)
cu_seqlens = [0]
for l in seq_lens:
cu_seqlens.append(cu_seqlens[-1] + l)
cu_seqlens.append(cu_seqlens[-1] + seq_len)

return {
"input_ids": [tokens],
Expand Down