Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/online_serving/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,9 @@ prompt_token_ids: Optional[List[int]] = None
disable_chat_template: Optional[bool] = False
# Whether to disable chat template rendering, using raw input directly (default False means template is enabled).

disable_prefix_caching: Optional[bool] = False
# Whether to disable prefix caching for the current request, including skipping prefix cache matching, writing, and release paths for cache reuse (default False means following the global prefix caching configuration).

temp_scaled_logprobs: Optional[bool] = False
# Whether to divide the logits by the temperature coefficient when calculating logprobs (default is False, meaning the logits are not divided by the temperature coefficient).

Expand Down
3 changes: 3 additions & 0 deletions docs/zh/online_serving/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,9 @@ prompt_token_ids: Optional[List[int]] = None
disable_chat_template: Optional[bool] = False
# 是否禁用聊天模板渲染,直接使用原始输入(默认 False 表示启用模板)。

disable_prefix_caching: Optional[bool] = False
# 是否对当前请求禁用 prefix caching,包括跳过 prefix cache 匹配、写入和释放等复用缓存流程(默认 False 表示按全局 prefix caching 配置执行)。

temp_scaled_logprobs: Optional[bool] = False
# 计算logprob时是否对logits除以温度系数(默认 False 表示不除以温度系数)。

Expand Down
5 changes: 5 additions & 0 deletions fastdeploy/engine/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __init__(
multimodal_data: Optional[dict] = None,
disable_chat_template: bool = False,
disaggregate_info: Optional[dict] = None,
disable_prefix_caching: bool = False,
draft_token_ids: Optional[list[int]] = None,
guided_json: Optional[Any] = None,
guided_regex: Optional[Any] = None,
Expand Down Expand Up @@ -148,6 +149,7 @@ def __init__(
self.num_cached_blocks = 0
self.disable_chat_template = disable_chat_template
self.disaggregate_info = disaggregate_info
self.disable_prefix_caching = disable_prefix_caching

# speculative method in disaggregate-mode
self.draft_token_ids = draft_token_ids
Expand Down Expand Up @@ -269,6 +271,7 @@ def from_generic_request(
metrics=metrics,
guided_json_object=guided_json_object,
disaggregate_info=getattr(req, "disaggregate_info", None),
disable_prefix_caching=getattr(req, "disable_prefix_caching", False),
guided_json=getattr(req, "guided_json", None),
guided_regex=getattr(req, "guided_regex", None),
guided_choice=getattr(req, "guided_choice", None),
Expand Down Expand Up @@ -373,6 +376,7 @@ def from_dict(cls, d: dict):
multimodal_data=d.get("multimodal_data"),
disable_chat_template=d.get("disable_chat_template"),
disaggregate_info=d.get("disaggregate_info"),
disable_prefix_caching=d.get("disable_prefix_caching", False),
draft_token_ids=d.get("draft_token_ids"),
guided_json=d.get("guided_json", None),
guided_regex=d.get("guided_regex", None),
Expand Down Expand Up @@ -445,6 +449,7 @@ def to_dict(self) -> dict:
"multimodal_data": self.multimodal_data,
"disable_chat_template": self.disable_chat_template,
"disaggregate_info": self.disaggregate_info,
"disable_prefix_caching": self.disable_prefix_caching,
"draft_token_ids": self.draft_token_ids,
"enable_thinking": self.enable_thinking,
"reasoning_max_tokens": self.reasoning_max_tokens,
Expand Down
7 changes: 5 additions & 2 deletions fastdeploy/engine/resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ def check_and_free_block_tables(self):
if self.available_block_num() < self.cfg.max_block_num_per_seq:
self.free_block_tables(self.cfg.max_block_num_per_seq)

def _enable_prefix_cache_for_request(self, task):
return self.enable_prefix_cache and not getattr(task, "disable_prefix_caching", False)

def _recycle_block_tables(self, task):
"""
Recycling memory resource blocks
Expand All @@ -159,7 +162,7 @@ def _recycle_block_tables(self, task):
block_tables (list): block list
"""

if self.enable_prefix_cache:
if self._enable_prefix_cache_for_request(task):
self.cache_manager.release_block_ids_async(task)
else:
req_id = task.request_id
Expand Down Expand Up @@ -245,7 +248,7 @@ def allocate_resources_for_new_tasks(self, tasks):
task.set("seed", random.randint(0, 9223372036854775807))
task.idx = allocated_position

if self.enable_prefix_cache: # if prefix caching is enabled
if self._enable_prefix_cache_for_request(task): # if prefix caching is enabled
# 1. request for enough blocks for current task
cache_prepare_time = time.time()
common_block_ids, unique_block_ids, hit_info = self.cache_manager.request_block_ids(
Expand Down
65 changes: 44 additions & 21 deletions fastdeploy/engine/sched/resource_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l
def allocated_slots(self, request: Request):
return len(request.block_tables) * self.config.cache_config.block_size

def _enable_prefix_cache_for_request(self, request: Request):
return self.config.cache_config.enable_prefix_caching and not getattr(request, "disable_prefix_caching", False)

def get_new_block_nums(self, request: Request, num_new_tokens: int):
block_num = (
request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1
Expand Down Expand Up @@ -405,13 +408,17 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re
if preempted_req.request_id in self.req_dict:
del self.req_dict[preempted_req.request_id]
if envs.FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST:
if self.config.cache_config.kvcache_storage_backend:
if self.config.cache_config.kvcache_storage_backend and self._enable_prefix_cache_for_request(
preempted_req
):
self.cache_manager.write_cache_to_storage_decode(preempted_req)
self._free_blocks(preempted_req)
llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
else:
if envs.FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST:
if self.config.cache_config.kvcache_storage_backend:
if self.config.cache_config.kvcache_storage_backend and self._enable_prefix_cache_for_request(
preempted_req
):
self.cache_manager.write_cache_to_storage(preempted_req)
self._free_blocks(preempted_req)
preempted_req.num_cached_blocks = 0
Expand Down Expand Up @@ -830,7 +837,7 @@ def add_abort_req_ids(self, req_ids):

def cache_output_tokens(self, request):
if (
self.config.cache_config.enable_prefix_caching
self._enable_prefix_cache_for_request(request)
and self.config.cache_config.enable_output_caching
and self.config.scheduler_config.splitwise_role != "decode"
):
Expand Down Expand Up @@ -1015,7 +1022,7 @@ def _allocate_decode_and_extend():
token_budget -= num_new_tokens
request.num_computed_tokens += num_new_tokens
if (
self.config.cache_config.enable_prefix_caching
self._enable_prefix_cache_for_request(request)
and self.config.scheduler_config.splitwise_role != "decode"
and self.config.scheduler_config.splitwise_role != "prefill"
):
Expand Down Expand Up @@ -1058,7 +1065,7 @@ def _allocate_decode_and_extend():

self._update_mm_hashes(request)
# Enable prefix caching
if self.config.cache_config.enable_prefix_caching:
if self._enable_prefix_cache_for_request(request):
if (
self.cache_manager.num_cpu_blocks > 0
or self.config.cache_config.kvcache_storage_backend
Expand Down Expand Up @@ -1087,7 +1094,7 @@ def _allocate_decode_and_extend():
break
num_new_tokens = self._get_num_new_tokens(request, token_budget)
if num_new_tokens == 0:
if self.config.cache_config.enable_prefix_caching:
if self._enable_prefix_cache_for_request(request):
self._free_blocks(request)
skip_requests.append(request)
self.waiting.popleft()
Expand All @@ -1109,7 +1116,7 @@ def _allocate_decode_and_extend():
token_budget -= num_new_tokens
request.num_computed_tokens += num_new_tokens
if (
self.config.cache_config.enable_prefix_caching
self._enable_prefix_cache_for_request(request)
and self.config.scheduler_config.splitwise_role != "decode"
):
self.cache_manager.update_cache_blocks(
Expand All @@ -1124,15 +1131,15 @@ def _allocate_decode_and_extend():
self.req_dict[request.request_id] = allocated_position
llm_logger.debug(f"req_id:{request.request_id} allocate pos end")
else:
if self.config.cache_config.enable_prefix_caching:
if self._enable_prefix_cache_for_request(request):
self._free_blocks(request)
break
elif request.status == RequestStatus.PREEMPTED:
request.need_prefill_tokens = (
request.num_total_tokens
) # Before preempted task rescheduled, preempted task has been sent to engine, no more tokens are output, here num_total_tokens should be static and correct
if (
self.config.cache_config.enable_prefix_caching
self._enable_prefix_cache_for_request(request)
and self.config.scheduler_config.splitwise_role != "decode"
):
if (
Expand All @@ -1156,7 +1163,7 @@ def _allocate_decode_and_extend():
break
num_new_tokens = self._get_num_new_tokens(request, token_budget)
if num_new_tokens == 0:
if self.config.cache_config.enable_prefix_caching:
if self._enable_prefix_cache_for_request(request):
self._free_blocks(request)
skip_requests.append(request)
self.waiting.popleft()
Expand All @@ -1178,15 +1185,15 @@ def _allocate_decode_and_extend():
token_budget -= num_new_tokens
request.num_computed_tokens += num_new_tokens
if (
self.config.cache_config.enable_prefix_caching
self._enable_prefix_cache_for_request(request)
and self.config.scheduler_config.splitwise_role != "decode"
):
self.cache_manager.update_cache_blocks(
request, self.config.cache_config.block_size, request.num_computed_tokens
)
request.status = RequestStatus.RUNNING_PREFILL
else:
if self.config.cache_config.enable_prefix_caching:
if self._enable_prefix_cache_for_request(request):
self._free_blocks(request)
break
else:
Expand Down Expand Up @@ -1347,6 +1354,19 @@ def get_prefix_cached_blocks(self, request: Request):
"""
Match and fetch cache for a task.
"""
if not self._enable_prefix_cache_for_request(request):
request.cache_info = [
0,
self.cache_manager.get_required_block_num(
request.need_prefill_tokens,
self.config.cache_config.block_size,
),
]
request.block_tables = []
request.num_cached_tokens = 0
request.num_computed_tokens = 0
return True

try:
trace_print(LoggingEventName.PREPARE_PREFIX_CACHE_START, request.request_id, getattr(request, "user", ""))

Expand Down Expand Up @@ -1452,7 +1472,7 @@ def preallocate_resource_in_p(self, request: Request):
need_prealloc_prefill_blocks = (
request.need_prefill_tokens + self.config.cache_config.block_size - 1
) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num # consider for mtp, plus enc_dec_block_num
if self.config.cache_config.enable_prefix_caching:
if self._enable_prefix_cache_for_request(request):
# Enable prefix caching
if self.cache_manager.num_cpu_blocks > 0 or self.config.cache_config.kvcache_storage_backend:
if not self.cache_manager.can_allocate_gpu_blocks(
Expand Down Expand Up @@ -1609,7 +1629,7 @@ def _write_prefill_routing_to_host_buffer(self, request, routing_data):
self.routing_host_view.scatter(slot_mapping, routing_data)

def _free_blocks(self, request: Request):
if self.config.cache_config.enable_prefix_caching and self.config.scheduler_config.splitwise_role != "decode":
if self._enable_prefix_cache_for_request(request) and self.config.scheduler_config.splitwise_role != "decode":
self.cache_manager.release_block_ids(request)
self.cache_manager.recycle_gpu_blocks(
request.block_tables[request.num_cached_blocks :], request.request_id
Expand Down Expand Up @@ -1672,13 +1692,16 @@ def finish_requests(self, request_ids: Union[str, Iterable[str]]):

# Do not block the main thread here
# Write cache to storage if kvcache_storage_backend is enabled
for req in need_postprocess_reqs:
if self.config.scheduler_config.splitwise_role == "decode":
# D instance uses simplified write method (does not rely on Radix Tree)
self.cache_manager.write_cache_to_storage_decode(req)
else:
# P instance / Mixed instance uses standard write method (relies on Radix Tree)
self.cache_manager.write_cache_to_storage(req)
if self.config.cache_config.kvcache_storage_backend:
for req in need_postprocess_reqs:
if not self._enable_prefix_cache_for_request(req):
continue
if self.config.scheduler_config.splitwise_role == "decode":
# D instance uses simplified write method (does not rely on Radix Tree)
self.cache_manager.write_cache_to_storage_decode(req)
else:
# P instance / Mixed instance uses standard write method (relies on Radix Tree)
self.cache_manager.write_cache_to_storage(req)

with self.lock:
for req in need_postprocess_reqs:
Expand Down
2 changes: 2 additions & 0 deletions fastdeploy/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,7 @@ class CompletionRequest(BaseModel):
user: Optional[str] = None
request_id: Optional[str] = None
disaggregate_info: Optional[dict] = None
disable_prefix_caching: Optional[bool] = False

# doc: begin-completion-sampling-params
top_k: Optional[int] = None
Expand Down Expand Up @@ -707,6 +708,7 @@ class ChatCompletionRequest(BaseModel):
response_format: Optional[AnyResponseFormat] = None
request_id: Optional[str] = None
disaggregate_info: Optional[dict] = None
disable_prefix_caching: Optional[bool] = False

# doc: begin-chat-completion-sampling-params
top_k: Optional[int] = None
Expand Down
Loading
Loading