23 changes: 11 additions & 12 deletions fastdeploy/engine/common_engine.py
@@ -1026,18 +1026,17 @@ def _fetch_request():
break
else:
raise
if self.cfg.scheduler_config.splitwise_role != "mixed":
                    # Continue preprocessing incoming requests and accumulating them in the queue while the forward pass has not finished.
# Once the forward pass finishes, these accumulated requests can be scheduled in larger,
# more efficient batches.
if self.engine_worker_queue.exist_tasks() or self.engine_forward_signal.value[0] != 0:
time.sleep(0.001)
continue
else:
                    # In mixed, todo: optimize cache swap, to decouple swap from scheduler
if self.engine_worker_queue.exist_tasks():
time.sleep(0.001)
continue
                # Continue preprocessing incoming requests and accumulating them in the queue while the forward pass has not finished.
# Once the forward pass finishes, these accumulated requests can be scheduled in larger,
# more efficient batches.
if (self.engine_worker_queue.exist_tasks() or self.engine_forward_signal.value[0] != 0) and (
self.resource_manager.cache_manager.num_cpu_blocks == 0
or self.cfg.scheduler_config.splitwise_role != "mixed"
):
                    # In mixed, when CPU cache is enabled, cache swapping happens in schedule(). If we only called schedule() after forward, TTFT would degrade due to non-overlapped swapping.
                    # todo: support cache swapping before inserting into the waiting list.
time.sleep(0.001)
continue

if hasattr(self.resource_manager, "scheduler_unhandled_request_num"):
self.resource_manager.scheduler_unhandled_request_num = self._get_scheduler_unhandled_request_num()
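To make the merged condition easier to reason about, the sketch below distills the fetch-loop gating into a standalone predicate. The helper should_defer_fetch and its parameters are hypothetical stand-ins for the engine state read above (engine_worker_queue.exist_tasks(), engine_forward_signal.value[0], cache_manager.num_cpu_blocks, scheduler_config.splitwise_role); this is a minimal illustration of the control flow, not FastDeploy's actual API.

def should_defer_fetch(
    has_queued_tasks: bool,     # stand-in for engine_worker_queue.exist_tasks()
    forward_in_progress: bool,  # stand-in for engine_forward_signal.value[0] != 0
    num_cpu_blocks: int,        # stand-in for cache_manager.num_cpu_blocks
    splitwise_role: str,        # "mixed", "prefill", or "decode"
) -> bool:
    """Return True when the fetch loop should sleep(0.001) and retry.

    Deferring lets requests accumulate while a forward pass runs, so they
    can be scheduled in one larger batch afterwards -- except in mixed role
    with CPU cache enabled, where schedule() must keep running so cache
    swapping overlaps the forward pass instead of delaying TTFT.
    """
    work_pending = has_queued_tasks or forward_in_progress
    swap_runs_in_schedule = num_cpu_blocks > 0 and splitwise_role == "mixed"
    return work_pending and not swap_runs_in_schedule


# Mixed role with CPU cache blocks available: never defer, keep schedule() hot.
assert not should_defer_fetch(True, True, num_cpu_blocks=128, splitwise_role="mixed")
# Disaggregated prefill instance mid-forward: defer and batch up new requests.
assert should_defer_fetch(False, True, num_cpu_blocks=0, splitwise_role="prefill")

By De Morgan's laws, "not (num_cpu_blocks > 0 and splitwise_role == 'mixed')" is exactly the "num_cpu_blocks == 0 or splitwise_role != 'mixed'" clause in the diff, just phrased positively as "swapping is already handled in schedule()".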
75 changes: 42 additions & 33 deletions fastdeploy/scheduler/local_scheduler.py
@@ -258,43 +258,52 @@ def get_requests(
f"max_num_batched_tokens={max_num_batched_tokens}"
)
return []

requests: List[Request] = []
with self.requests_not_empty:
batch_ids = self.requests_not_empty.wait_for(
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
self.wait_request_timeout,
)

requests: List[Request] = []
required_total_blocks = 0
current_prefill_tokens = 0
long_partial_requests, short_partial_requests = 0, 0
for request_id in batch_ids:
request = self.requests[request_id]
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
current_prefill_tokens += request.prompt_tokens_ids_len
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks:
break

if not envs.FD_ENABLE_MAX_PREFILL:
if self.enable_chunked_prefill:
if request.prompt_tokens_ids_len > self.long_prefill_token_threshold:
                            # long request
long_partial_requests += 1
if long_partial_requests > self.max_long_partial_prefills:
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
batch_ids = self.requests_not_empty.wait_for(
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + 1],
0.005,
)
if batch_ids:
for request_id in batch_ids:
request = self.requests[request_id]
requests.append(request.raw)
self.ids_read_cursor += 1
else:
batch_ids = self.requests_not_empty.wait_for(
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
self.wait_request_timeout,
)
required_total_blocks = 0
current_prefill_tokens = 0
long_partial_requests, short_partial_requests = 0, 0
for request_id in batch_ids:
request = self.requests[request_id]
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
current_prefill_tokens += request.prompt_tokens_ids_len
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks:
break

if not envs.FD_ENABLE_MAX_PREFILL:
if self.enable_chunked_prefill:
if request.prompt_tokens_ids_len > self.long_prefill_token_threshold:
                                # long request
long_partial_requests += 1
if long_partial_requests > self.max_long_partial_prefills:
break
else:
short_partial_requests += 1

if short_partial_requests + long_partial_requests > self.max_num_partial_prefills:
break
else:
short_partial_requests += 1

if short_partial_requests + long_partial_requests > self.max_num_partial_prefills:
break
else:
if current_prefill_tokens > max_num_batched_tokens and len(requests) > 0:
break
requests.append(request.raw)
if current_prefill_tokens > max_num_batched_tokens and len(requests) > 0:
break
requests.append(request.raw)

self.ids_read_cursor += len(requests)
self.ids_read_cursor += len(requests)

if len(batch_ids) > 0 and len(requests) == 0:
scheduler_logger.debug(f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}")
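The reshaped get_requests now has two polling modes: under ENABLE_V1_KVCACHE_SCHEDULER it takes at most one request per short 5 ms poll and defers all capacity checks to the downstream scheduler, while the legacy path waits up to wait_request_timeout for a whole batch and enforces the block and partial-prefill budgets inline. The toy class below models just that control flow; TinyScheduler and its fields are illustrative stand-ins (a threading.Condition plays the role of requests_not_empty), not the real LocalScheduler.

import threading
from typing import List


class TinyScheduler:
    """Toy model of the two polling modes; not FastDeploy's LocalScheduler."""

    def __init__(self, v1_kvcache_scheduler: bool) -> None:
        self.v1 = v1_kvcache_scheduler
        self.ids: List[str] = []          # request ids, append-only
        self.cursor = 0                   # analogous to ids_read_cursor
        self.not_empty = threading.Condition()

    def add(self, request_id: str) -> None:
        with self.not_empty:
            self.ids.append(request_id)
            self.not_empty.notify_all()

    def get_requests(self, batch: int, wait_timeout: float) -> List[str]:
        with self.not_empty:
            if self.v1:
                # V1 path: take at most one id per short 5 ms poll; capacity
                # checks are deferred to the downstream KV-cache scheduler.
                got = self.not_empty.wait_for(
                    lambda: self.ids[self.cursor : self.cursor + 1],
                    timeout=0.005,
                )
            else:
                # Legacy path: wait for up to `batch` ids, after which the
                # caller applies block and partial-prefill budgets.
                got = self.not_empty.wait_for(
                    lambda: self.ids[self.cursor : self.cursor + batch],
                    timeout=wait_timeout,
                )
            taken = list(got) if got else []
            self.cursor += len(taken)
            return taken

What the toy omits is the legacy branch's early exit: the real loop stops appending once required_total_blocks would exceed available_blocks, or once the long/short partial-prefill counters hit max_long_partial_prefills / max_num_partial_prefills, which is why requests can end up shorter than batch_ids and ids_read_cursor advances only by len(requests).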