diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index d0f39ee8e73..bd6628d6702 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -1026,18 +1026,17 @@ def _fetch_request(): break else: raise - if self.cfg.scheduler_config.splitwise_role != "mixed": - # Continue preprocessing incoming requests and accumulating them in the queue when forward pass not finished. - # Once the forward pass finishes, these accumulated requests can be scheduled in larger, - # more efficient batches. - if self.engine_worker_queue.exist_tasks() or self.engine_forward_signal.value[0] != 0: - time.sleep(0.001) - continue - else: - # In mixed, todo: optimze cache swap, to decouple swap from scheduler - if self.engine_worker_queue.exist_tasks(): - time.sleep(0.001) - continue + # Continue preprocessing incoming requests and accumulating them in the queue when forward pass not finished. + # Once the forward pass finishes, these accumulated requests can be scheduled in larger, + # more efficient batches. + if (self.engine_worker_queue.exist_tasks() or self.engine_forward_signal.value[0] != 0) and ( + self.resource_manager.cache_manager.num_cpu_blocks == 0 + or self.cfg.scheduler_config.splitwise_role != "mixed" + ): + # In mixed, when cpu cache is enabled, cache swapping happens in schedule(). If we call schedule() after forward, TTFT will degrade due to non-overlapped swapping. + # TODO: support cache swapping before inserting into waiting list. 
+ time.sleep(0.001) + continue if hasattr(self.resource_manager, "scheduler_unhandled_request_num"): self.resource_manager.scheduler_unhandled_request_num = self._get_scheduler_unhandled_request_num() diff --git a/fastdeploy/scheduler/local_scheduler.py b/fastdeploy/scheduler/local_scheduler.py index fc4a64686b5..db202a16d53 100644 --- a/fastdeploy/scheduler/local_scheduler.py +++ b/fastdeploy/scheduler/local_scheduler.py @@ -258,43 +258,52 @@ def get_requests( f"max_num_batched_tokens={max_num_batched_tokens}" ) return [] - + requests: List[Request] = [] with self.requests_not_empty: - batch_ids = self.requests_not_empty.wait_for( - lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch], - self.wait_request_timeout, - ) - - requests: List[Request] = [] - required_total_blocks = 0 - current_prefill_tokens = 0 - long_partial_requests, short_partial_requests = 0, 0 - for request_id in batch_ids: - request = self.requests[request_id] - required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size) - current_prefill_tokens += request.prompt_tokens_ids_len - required_total_blocks += required_input_blocks + reserved_output_blocks - if required_total_blocks > available_blocks: - break - - if not envs.FD_ENABLE_MAX_PREFILL: - if self.enable_chunked_prefill: - if request.prompt_tokens_ids_len > self.long_prefill_token_threshold: - # 长请求 - long_partial_requests += 1 - if long_partial_requests > self.max_long_partial_prefills: + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + batch_ids = self.requests_not_empty.wait_for( + lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + 1], + 0.005, + ) + if batch_ids: + for request_id in batch_ids: + request = self.requests[request_id] + requests.append(request.raw) + self.ids_read_cursor += 1 + else: + batch_ids = self.requests_not_empty.wait_for( + lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch], + self.wait_request_timeout, + ) + required_total_blocks = 0 + 
current_prefill_tokens = 0 + long_partial_requests, short_partial_requests = 0, 0 + for request_id in batch_ids: + request = self.requests[request_id] + required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size) + current_prefill_tokens += request.prompt_tokens_ids_len + required_total_blocks += required_input_blocks + reserved_output_blocks + if required_total_blocks > available_blocks: + break + + if not envs.FD_ENABLE_MAX_PREFILL: + if self.enable_chunked_prefill: + if request.prompt_tokens_ids_len > self.long_prefill_token_threshold: + # 长请求 + long_partial_requests += 1 + if long_partial_requests > self.max_long_partial_prefills: + break + else: + short_partial_requests += 1 + + if short_partial_requests + long_partial_requests > self.max_num_partial_prefills: break else: - short_partial_requests += 1 - - if short_partial_requests + long_partial_requests > self.max_num_partial_prefills: - break - else: - if current_prefill_tokens > max_num_batched_tokens and len(requests) > 0: - break - requests.append(request.raw) + if current_prefill_tokens > max_num_batched_tokens and len(requests) > 0: + break + requests.append(request.raw) - self.ids_read_cursor += len(requests) + self.ids_read_cursor += len(requests) if len(batch_ids) > 0 and len(requests) == 0: scheduler_logger.debug(f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}")