Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions lightllm/server/router/req_queue/chunked_prefill/beam_impl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import uuid
import time
from typing import List
from ...batch import Batch, Req
from lightllm.server.router.req_queue.base_queue import BaseQueue
Expand Down Expand Up @@ -69,6 +70,22 @@ def _can_add_new_group_reqs(self, cur_handle_group_reqs: List[Req], is_busy, new
else:
return False, new_batch_first_router_need_tokens

def _filter_aborted_reqs(self):
    """Drop requests already marked aborted from the waiting queue.

    When at least one aborted request is present, sleep 10 ms first so the
    httpserver has time to mark every request belonging to the same group
    as aborted, then remove them all in one pass and release the resources
    (CPU cache pages and the shared-memory request object) they hold.
    """
    if not any(r.is_aborted for r in self.waiting_req_list):
        return
    # Brief pause so the whole request group gets flagged before removal.
    time.sleep(0.01)
    kept = []
    dropped = []
    for r in self.waiting_req_list:
        if r.is_aborted:
            dropped.append(r)
        else:
            kept.append(r)
    self.waiting_req_list = kept
    for req in dropped:
        req: Req = req
        logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}")
        self.free_aborted_req_cpu_cache_pages(req)
        self.router.shm_req_manager.put_back_req_obj(req)
Comment on lines +77 to +86
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

这个方法有两点可以改进:

  1. 性能: 当前实现为了过滤中止的请求,最多会遍历 self.waiting_req_list 三次。这可以优化,通过使用 any() 进行初步检查,然后在单次遍历中对列表进行分区。
  2. 可靠性与性能: 使用阻塞的 time.sleep(0.1) 进行同步是脆弱的,并且存在性能问题。它会暂停线程,增加延迟,并且如果其他请求并发中止,等待时间可能不足以防止竞争条件。一个更健壮的同步机制,如条件变量或事件,将是更好的长期解决方案。

下面的建议代码重构了过滤逻辑以提高性能。time.sleep() 的问题更为根本,应通过更稳健的同步策略来解决。

Suggested change
exist_aborted_req = len([req for req in self.waiting_req_list if req.is_aborted]) > 0
if exist_aborted_req:
time.sleep(0.1)
aborted_reqs = [req for req in self.waiting_req_list if req.is_aborted]
self.waiting_req_list = [req for req in self.waiting_req_list if not req.is_aborted]
for req in aborted_reqs:
req: Req = req
logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}")
self.free_aborted_req_cpu_cache_pages(req)
self.router.shm_req_manager.put_back_req_obj(req)
if any(req.is_aborted for req in self.waiting_req_list):
time.sleep(0.1)
aborted_reqs = []
new_waiting_req_list = []
for req in self.waiting_req_list:
if req.is_aborted:
aborted_reqs.append(req)
else:
new_waiting_req_list.append(req)
self.waiting_req_list = new_waiting_req_list
for req in aborted_reqs:
logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}")
self.free_aborted_req_cpu_cache_pages(req)
self.router.shm_req_manager.put_back_req_obj(req)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

等待队列一般规模不会超过1000,同时生成式写法python执行效率更高,在小数组情况下更适合。

return

# @calculate_time(show=True, min_cost_ms=10)
def generate_new_batch(self, current_batch: Batch):
if len(self.waiting_req_list) == 0:
Expand All @@ -79,6 +96,8 @@ def generate_new_batch(self, current_batch: Batch):
req_is_full = exist_req_num >= self.running_max_req_size
if req_is_full:
return None

self._filter_aborted_reqs()
if len(self.waiting_req_list) == 0:
return None

Expand All @@ -87,15 +106,9 @@ def generate_new_batch(self, current_batch: Batch):

self._init_cache_list(current_batch, is_busy)
can_run_list = []
abort_req_list = []
new_batch_first_router_need_tokens = 0 # 主要是对 prefill 大块计算时候的token数量限制
aborted_count = 0
cur_group_reqs = []
for req in self.waiting_req_list:
if req.is_aborted:
aborted_count += 1
abort_req_list.append(req)
continue

if self._add_to_group(cur_group_reqs, req):
continue
Expand All @@ -121,12 +134,7 @@ def generate_new_batch(self, current_batch: Batch):
if len(can_run_list) != 0:
new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node)

for req in abort_req_list:
req: Req = req
logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}")
self.free_aborted_req_cpu_cache_pages(req)
self.router.shm_req_manager.put_back_req_obj(req)
self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
self.waiting_req_list = self.waiting_req_list[len(can_run_list) :]
return new_batch

def _add_to_group(self, cur_group_reqs, req: Req):
Expand Down
Loading