diff --git a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py index 2e50cae85..e1bdf5ab8 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py @@ -1,4 +1,5 @@ import uuid +import time from typing import List from ...batch import Batch, Req from lightllm.server.router.req_queue.base_queue import BaseQueue @@ -69,6 +70,22 @@ def _can_add_new_group_reqs(self, cur_handle_group_reqs: List[Req], is_busy, new else: return False, new_batch_first_router_need_tokens + def _filter_aborted_reqs(self): + # 先移除在等待队列中已经处于aborted状态的请求, 如果发现存在aborted的请求, + # 则休眠10ms,保证httpserver将所有属于一组的请求都置为aborted请求,再将 + # 请求从队列中移除。 + exist_aborted_req = len([req for req in self.waiting_req_list if req.is_aborted]) > 0 + if exist_aborted_req: + time.sleep(0.01) + aborted_reqs = [req for req in self.waiting_req_list if req.is_aborted] + self.waiting_req_list = [req for req in self.waiting_req_list if not req.is_aborted] + for req in aborted_reqs: + req: Req = req + logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}") + self.free_aborted_req_cpu_cache_pages(req) + self.router.shm_req_manager.put_back_req_obj(req) + return + # @calculate_time(show=True, min_cost_ms=10) def generate_new_batch(self, current_batch: Batch): if len(self.waiting_req_list) == 0: @@ -79,6 +96,8 @@ def generate_new_batch(self, current_batch: Batch): req_is_full = exist_req_num >= self.running_max_req_size if req_is_full: return None + + self._filter_aborted_reqs() if len(self.waiting_req_list) == 0: return None @@ -87,15 +106,9 @@ def generate_new_batch(self, current_batch: Batch): self._init_cache_list(current_batch, is_busy) can_run_list = [] - abort_req_list = [] new_batch_first_router_need_tokens = 0 # 主要是对 prefill 大块计算时候的token数量限制 - aborted_count = 0 cur_group_reqs = [] for req in self.waiting_req_list: - if req.is_aborted: - aborted_count += 1 - abort_req_list.append(req) - continue if self._add_to_group(cur_group_reqs, req): continue @@ -121,12 +134,7 @@ def generate_new_batch(self, current_batch: Batch): if len(can_run_list) != 0: new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node) - for req in abort_req_list: - req: Req = req - logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}") - self.free_aborted_req_cpu_cache_pages(req) - self.router.shm_req_manager.put_back_req_obj(req) - self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] + self.waiting_req_list = self.waiting_req_list[len(can_run_list) :] return new_batch def _add_to_group(self, cur_group_reqs, req: Req):