From 627ea215cb6c5fc74232309b79e96bf9ca5f79b0 Mon Sep 17 00:00:00 2001 From: kevincheng2 Date: Tue, 19 May 2026 17:11:09 +0800 Subject: [PATCH] feat(cache): support request-level prefix cache disable --- docs/online_serving/README.md | 3 + docs/zh/online_serving/README.md | 3 + fastdeploy/engine/request.py | 5 ++ fastdeploy/engine/resource_manager.py | 7 +- .../engine/sched/resource_manager_v1.py | 65 ++++++++++++------ fastdeploy/entrypoints/openai/protocol.py | 2 + tests/engine/test_request.py | 34 +++++++++- tests/engine/test_resource_manager.py | 25 ++++++- tests/v1/test_resource_manager_v1.py | 66 +++++++++++++++++++ 9 files changed, 184 insertions(+), 26 deletions(-) diff --git a/docs/online_serving/README.md b/docs/online_serving/README.md index c9dba035339..6faff9d1c1a 100644 --- a/docs/online_serving/README.md +++ b/docs/online_serving/README.md @@ -213,6 +213,9 @@ prompt_token_ids: Optional[List[int]] = None disable_chat_template: Optional[bool] = False # Whether to disable chat template rendering, using raw input directly (default False means template is enabled). +disable_prefix_caching: Optional[bool] = False +# Whether to disable prefix caching for the current request, including skipping prefix cache matching, writing, and release paths for cache reuse (default False means following the global prefix caching configuration). + temp_scaled_logprobs: Optional[bool] = False # Whether to divide the logits by the temperature coefficient when calculating logprobs (default is False, meaning the logits are not divided by the temperature coefficient). 
diff --git a/docs/zh/online_serving/README.md b/docs/zh/online_serving/README.md index 0264c928bd5..13df887452d 100644 --- a/docs/zh/online_serving/README.md +++ b/docs/zh/online_serving/README.md @@ -209,6 +209,9 @@ prompt_token_ids: Optional[List[int]] = None disable_chat_template: Optional[bool] = False # 是否禁用聊天模板渲染,直接使用原始输入(默认 False 表示启用模板)。 +disable_prefix_caching: Optional[bool] = False +# 是否对当前请求禁用 prefix caching,包括跳过 prefix cache 匹配、写入和释放等复用缓存流程(默认 False 表示按全局 prefix caching 配置执行)。 + temp_scaled_logprobs: Optional[bool] = False # 计算logprob时是否对logits除以温度系数(默认 False 表示不除以温度系数)。 diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 05ee4a348ea..f1b57edfe6e 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -93,6 +93,7 @@ def __init__( multimodal_data: Optional[dict] = None, disable_chat_template: bool = False, disaggregate_info: Optional[dict] = None, + disable_prefix_caching: bool = False, draft_token_ids: Optional[list[int]] = None, guided_json: Optional[Any] = None, guided_regex: Optional[Any] = None, @@ -148,6 +149,7 @@ def __init__( self.num_cached_blocks = 0 self.disable_chat_template = disable_chat_template self.disaggregate_info = disaggregate_info + self.disable_prefix_caching = disable_prefix_caching # speculative method in disaggregate-mode self.draft_token_ids = draft_token_ids @@ -269,6 +271,7 @@ def from_generic_request( metrics=metrics, guided_json_object=guided_json_object, disaggregate_info=getattr(req, "disaggregate_info", None), + disable_prefix_caching=getattr(req, "disable_prefix_caching", False), guided_json=getattr(req, "guided_json", None), guided_regex=getattr(req, "guided_regex", None), guided_choice=getattr(req, "guided_choice", None), @@ -373,6 +376,7 @@ def from_dict(cls, d: dict): multimodal_data=d.get("multimodal_data"), disable_chat_template=d.get("disable_chat_template"), disaggregate_info=d.get("disaggregate_info"), + disable_prefix_caching=d.get("disable_prefix_caching", 
False), draft_token_ids=d.get("draft_token_ids"), guided_json=d.get("guided_json", None), guided_regex=d.get("guided_regex", None), @@ -445,6 +449,7 @@ def to_dict(self) -> dict: "multimodal_data": self.multimodal_data, "disable_chat_template": self.disable_chat_template, "disaggregate_info": self.disaggregate_info, + "disable_prefix_caching": self.disable_prefix_caching, "draft_token_ids": self.draft_token_ids, "enable_thinking": self.enable_thinking, "reasoning_max_tokens": self.reasoning_max_tokens, diff --git a/fastdeploy/engine/resource_manager.py b/fastdeploy/engine/resource_manager.py index 173cbdf9dd7..47d47fd7637 100644 --- a/fastdeploy/engine/resource_manager.py +++ b/fastdeploy/engine/resource_manager.py @@ -151,6 +151,9 @@ def check_and_free_block_tables(self): if self.available_block_num() < self.cfg.max_block_num_per_seq: self.free_block_tables(self.cfg.max_block_num_per_seq) + def _enable_prefix_cache_for_request(self, task): + return self.enable_prefix_cache and not getattr(task, "disable_prefix_caching", False) + def _recycle_block_tables(self, task): """ Recycling memory resource blocks @@ -159,7 +162,7 @@ def _recycle_block_tables(self, task): block_tables (list): block list """ - if self.enable_prefix_cache: + if self._enable_prefix_cache_for_request(task): self.cache_manager.release_block_ids_async(task) else: req_id = task.request_id @@ -245,7 +248,7 @@ def allocate_resources_for_new_tasks(self, tasks): task.set("seed", random.randint(0, 9223372036854775807)) task.idx = allocated_position - if self.enable_prefix_cache: # if prefix caching is enabled + if self._enable_prefix_cache_for_request(task): # if prefix caching is enabled # 1. 
request for enough blocks for current task cache_prepare_time = time.time() common_block_ids, unique_block_ids, hit_info = self.cache_manager.request_block_ids( diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index de89ab3adca..161d78afd24 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -241,6 +241,9 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l def allocated_slots(self, request: Request): return len(request.block_tables) * self.config.cache_config.block_size + def _enable_prefix_cache_for_request(self, request: Request): + return self.config.cache_config.enable_prefix_caching and not getattr(request, "disable_prefix_caching", False) + def get_new_block_nums(self, request: Request, num_new_tokens: int): block_num = ( request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1 @@ -405,13 +408,17 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re if preempted_req.request_id in self.req_dict: del self.req_dict[preempted_req.request_id] if envs.FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST: - if self.config.cache_config.kvcache_storage_backend: + if self.config.cache_config.kvcache_storage_backend and self._enable_prefix_cache_for_request( + preempted_req + ): self.cache_manager.write_cache_to_storage_decode(preempted_req) self._free_blocks(preempted_req) llm_logger.info(f"Preemption is triggered! 
Preempted request id: {preempted_req.request_id}") else: if envs.FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST: - if self.config.cache_config.kvcache_storage_backend: + if self.config.cache_config.kvcache_storage_backend and self._enable_prefix_cache_for_request( + preempted_req + ): self.cache_manager.write_cache_to_storage(preempted_req) self._free_blocks(preempted_req) preempted_req.num_cached_blocks = 0 @@ -830,7 +837,7 @@ def add_abort_req_ids(self, req_ids): def cache_output_tokens(self, request): if ( - self.config.cache_config.enable_prefix_caching + self._enable_prefix_cache_for_request(request) and self.config.cache_config.enable_output_caching and self.config.scheduler_config.splitwise_role != "decode" ): @@ -1015,7 +1022,7 @@ def _allocate_decode_and_extend(): token_budget -= num_new_tokens request.num_computed_tokens += num_new_tokens if ( - self.config.cache_config.enable_prefix_caching + self._enable_prefix_cache_for_request(request) and self.config.scheduler_config.splitwise_role != "decode" and self.config.scheduler_config.splitwise_role != "prefill" ): @@ -1058,7 +1065,7 @@ def _allocate_decode_and_extend(): self._update_mm_hashes(request) # Enable prefix caching - if self.config.cache_config.enable_prefix_caching: + if self._enable_prefix_cache_for_request(request): if ( self.cache_manager.num_cpu_blocks > 0 or self.config.cache_config.kvcache_storage_backend @@ -1087,7 +1094,7 @@ def _allocate_decode_and_extend(): break num_new_tokens = self._get_num_new_tokens(request, token_budget) if num_new_tokens == 0: - if self.config.cache_config.enable_prefix_caching: + if self._enable_prefix_cache_for_request(request): self._free_blocks(request) skip_requests.append(request) self.waiting.popleft() @@ -1109,7 +1116,7 @@ def _allocate_decode_and_extend(): token_budget -= num_new_tokens request.num_computed_tokens += num_new_tokens if ( - self.config.cache_config.enable_prefix_caching + self._enable_prefix_cache_for_request(request) and 
self.config.scheduler_config.splitwise_role != "decode" ): self.cache_manager.update_cache_blocks( @@ -1124,7 +1131,7 @@ def _allocate_decode_and_extend(): self.req_dict[request.request_id] = allocated_position llm_logger.debug(f"req_id:{request.request_id} allocate pos end") else: - if self.config.cache_config.enable_prefix_caching: + if self._enable_prefix_cache_for_request(request): self._free_blocks(request) break elif request.status == RequestStatus.PREEMPTED: @@ -1132,7 +1139,7 @@ def _allocate_decode_and_extend(): request.num_total_tokens ) # Before preempted task rescheduled, preempted task has been sent to engine, no more tokens are output, here num_total_tokens should be static and correct if ( - self.config.cache_config.enable_prefix_caching + self._enable_prefix_cache_for_request(request) and self.config.scheduler_config.splitwise_role != "decode" ): if ( @@ -1156,7 +1163,7 @@ def _allocate_decode_and_extend(): break num_new_tokens = self._get_num_new_tokens(request, token_budget) if num_new_tokens == 0: - if self.config.cache_config.enable_prefix_caching: + if self._enable_prefix_cache_for_request(request): self._free_blocks(request) skip_requests.append(request) self.waiting.popleft() @@ -1178,7 +1185,7 @@ def _allocate_decode_and_extend(): token_budget -= num_new_tokens request.num_computed_tokens += num_new_tokens if ( - self.config.cache_config.enable_prefix_caching + self._enable_prefix_cache_for_request(request) and self.config.scheduler_config.splitwise_role != "decode" ): self.cache_manager.update_cache_blocks( @@ -1186,7 +1193,7 @@ def _allocate_decode_and_extend(): ) request.status = RequestStatus.RUNNING_PREFILL else: - if self.config.cache_config.enable_prefix_caching: + if self._enable_prefix_cache_for_request(request): self._free_blocks(request) break else: @@ -1347,6 +1354,19 @@ def get_prefix_cached_blocks(self, request: Request): """ Match and fetch cache for a task. 
""" + if not self._enable_prefix_cache_for_request(request): + request.cache_info = [ + 0, + self.cache_manager.get_required_block_num( + request.need_prefill_tokens, + self.config.cache_config.block_size, + ), + ] + request.block_tables = [] + request.num_cached_tokens = 0 + request.num_computed_tokens = 0 + return True + try: trace_print(LoggingEventName.PREPARE_PREFIX_CACHE_START, request.request_id, getattr(request, "user", "")) @@ -1452,7 +1472,7 @@ def preallocate_resource_in_p(self, request: Request): need_prealloc_prefill_blocks = ( request.need_prefill_tokens + self.config.cache_config.block_size - 1 ) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num # consider for mtp, plus enc_dec_block_num - if self.config.cache_config.enable_prefix_caching: + if self._enable_prefix_cache_for_request(request): # Enable prefix caching if self.cache_manager.num_cpu_blocks > 0 or self.config.cache_config.kvcache_storage_backend: if not self.cache_manager.can_allocate_gpu_blocks( @@ -1609,7 +1629,7 @@ def _write_prefill_routing_to_host_buffer(self, request, routing_data): self.routing_host_view.scatter(slot_mapping, routing_data) def _free_blocks(self, request: Request): - if self.config.cache_config.enable_prefix_caching and self.config.scheduler_config.splitwise_role != "decode": + if self._enable_prefix_cache_for_request(request) and self.config.scheduler_config.splitwise_role != "decode": self.cache_manager.release_block_ids(request) self.cache_manager.recycle_gpu_blocks( request.block_tables[request.num_cached_blocks :], request.request_id @@ -1672,13 +1692,16 @@ def finish_requests(self, request_ids: Union[str, Iterable[str]]): # Do not block the main thread here # Write cache to storage if kvcache_storage_backend is enabled - for req in need_postprocess_reqs: - if self.config.scheduler_config.splitwise_role == "decode": - # D instance uses simplified write method (does not rely on Radix Tree) - 
self.cache_manager.write_cache_to_storage_decode(req) - else: - # P instance / Mixed instance uses standard write method (relies on Radix Tree) - self.cache_manager.write_cache_to_storage(req) + if self.config.cache_config.kvcache_storage_backend: + for req in need_postprocess_reqs: + if not self._enable_prefix_cache_for_request(req): + continue + if self.config.scheduler_config.splitwise_role == "decode": + # D instance uses simplified write method (does not rely on Radix Tree) + self.cache_manager.write_cache_to_storage_decode(req) + else: + # P instance / Mixed instance uses standard write method (relies on Radix Tree) + self.cache_manager.write_cache_to_storage(req) with self.lock: for req in need_postprocess_reqs: diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index a546017d30f..8d7dc8fc2db 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -525,6 +525,7 @@ class CompletionRequest(BaseModel): user: Optional[str] = None request_id: Optional[str] = None disaggregate_info: Optional[dict] = None + disable_prefix_caching: Optional[bool] = False # doc: begin-completion-sampling-params top_k: Optional[int] = None @@ -707,6 +708,7 @@ class ChatCompletionRequest(BaseModel): response_format: Optional[AnyResponseFormat] = None request_id: Optional[str] = None disaggregate_info: Optional[dict] = None + disable_prefix_caching: Optional[bool] = False # doc: begin-chat-completion-sampling-params top_k: Optional[int] = None diff --git a/tests/engine/test_request.py b/tests/engine/test_request.py index fd9eab17dc7..c807e3cbd4f 100644 --- a/tests/engine/test_request.py +++ b/tests/engine/test_request.py @@ -32,7 +32,12 @@ SamplingParams, StructuralTagResponseFormat, ) -from fastdeploy.entrypoints.openai.protocol import ResponseFormat, StructuralTag +from fastdeploy.entrypoints.openai.protocol import ( + ChatCompletionRequest, + CompletionRequest, + ResponseFormat, + 
StructuralTag, +) class TestRequestInit(unittest.TestCase): @@ -60,6 +65,7 @@ def test_init_default_values(self): self.assertEqual(request.num_cached_blocks, 0) self.assertFalse(request.disable_chat_template) self.assertIsNone(request.disaggregate_info) + self.assertFalse(request.disable_prefix_caching) # Test multi-modal defaults self.assertIsNone(request.multimodal_inputs) @@ -102,6 +108,7 @@ def test_init_with_parameters(self): eos_token_ids=[0], disable_chat_template=True, disaggregate_info={"key": "value"}, + disable_prefix_caching=True, draft_token_ids=[4, 5], guided_json={"schema": "test"}, guided_regex="test.*", @@ -152,6 +159,7 @@ def test_init_with_parameters(self): # Test boolean parameters self.assertTrue(request.disable_chat_template) + self.assertTrue(request.disable_prefix_caching) self.assertTrue(request.guided_json_object) self.assertTrue(request.enable_thinking) self.assertTrue(request.add_generation_prompt) @@ -277,6 +285,7 @@ def test_from_generic_request(self): mock_generic_request.prompt_token_ids = [1, 2, 3] mock_generic_request.messages = [{"role": "user", "content": "Hello"}] mock_generic_request.disable_chat_template = True + mock_generic_request.disable_prefix_caching = True mock_generic_request.tools = [Mock()] mock_generic_request.tools[0].model_dump.return_value = {"name": "test_tool"} mock_generic_request.suffix = {"test": "value"} @@ -298,6 +307,7 @@ def test_from_generic_request(self): self.assertEqual(request.prompt_token_ids, [1, 2, 3]) self.assertEqual(request.messages, [{"role": "user", "content": "Hello"}]) self.assertTrue(request.disable_chat_template) + self.assertTrue(request.disable_prefix_caching) self.assertEqual(request.tools, [{"name": "test_tool"}]) self.assertIsInstance(request.metrics, RequestMetrics) @@ -320,6 +330,7 @@ def test_from_dict(self): "multimodal_data": {"images": ["img1"]}, "disable_chat_template": True, "disaggregate_info": {"key": "value"}, + "disable_prefix_caching": True, "draft_token_ids": [4, 5], 
"guided_json": {"schema": "test"}, "guided_regex": "test.*", @@ -363,6 +374,25 @@ def test_from_dict(self): # Test metrics creation self.assertIsInstance(request.metrics, RequestMetrics) self.assertEqual(request.metrics.arrival_time, 1000.0) + self.assertTrue(request.disable_prefix_caching) + + def test_openai_protocol_disable_prefix_caching(self): + """Test disable_prefix_caching defaults and request dict propagation.""" + completion = CompletionRequest(model="test", prompt="hello") + self.assertFalse(completion.disable_prefix_caching) + self.assertFalse(completion.to_dict_for_infer("req-1")["disable_prefix_caching"]) + + completion = CompletionRequest(model="test", prompt="hello", disable_prefix_caching=True) + self.assertTrue(completion.to_dict_for_infer("req-2")["disable_prefix_caching"]) + + chat = ChatCompletionRequest(model="test", messages=[{"role": "user", "content": "hello"}]) + self.assertFalse(chat.disable_prefix_caching) + self.assertFalse(chat.to_dict_for_infer("req-3")["disable_prefix_caching"]) + + chat = ChatCompletionRequest( + model="test", messages=[{"role": "user", "content": "hello"}], disable_prefix_caching=True + ) + self.assertTrue(chat.to_dict_for_infer("req-4")["disable_prefix_caching"]) class TestRequestInstanceMethods(unittest.TestCase): @@ -396,6 +426,7 @@ def test_to_dict_basic(self): request.prompt = "Hello" request.prompt_token_ids = [1, 2, 3] request.prompt_token_ids_len = 3 + request.disable_prefix_caching = True request.sampling_params = SamplingParams() request.metrics = RequestMetrics() request.metrics.prompt_token_ids_len = 3 @@ -406,6 +437,7 @@ def test_to_dict_basic(self): self.assertEqual(data["prompt"], "Hello") self.assertEqual(data["prompt_token_ids"], [1, 2, 3]) self.assertEqual(data["prompt_token_ids_len"], 3) + self.assertTrue(data["disable_prefix_caching"]) def test_to_dict_with_multimodal(self): """Test to_dict with multimodal inputs""" diff --git a/tests/engine/test_resource_manager.py 
b/tests/engine/test_resource_manager.py index e3eb5942c86..266b740a560 100644 --- a/tests/engine/test_resource_manager.py +++ b/tests/engine/test_resource_manager.py @@ -13,7 +13,7 @@ # limitations under the License. from types import SimpleNamespace -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest @@ -57,13 +57,14 @@ def request_block_ids(self, task, block_size, dec_token_num): class _Task: """Real task object with all fields ResourceManager touches.""" - def __init__(self, request_id="req-1", prompt_len=128, disaggregate_info=None): + def __init__(self, request_id="req-1", prompt_len=128, disaggregate_info=None, disable_prefix_caching=False): self.request_id = request_id self.prompt_token_ids = list(range(prompt_len)) self.prompt_token_ids_len = prompt_len self.block_tables = [] self.need_block_tables = [] self.disaggregate_info = disaggregate_info + self.disable_prefix_caching = disable_prefix_caching self.seq_lens_decoder = 0 self.inference_time_cost = -1.0 self.tokens_all_num = 0 @@ -190,6 +191,21 @@ def test_allocate_with_prefix(rm_factory): assert t.cache_info is not None +def test_allocate_with_prefix_disabled_cache(rm_factory): + """Request-level cache bypass uses ordinary allocation when global prefix cache is on.""" + rm = rm_factory(max_seqs=4, enable_prefix=True, dec_token=0, block_size=64, num_free=100) + rm.cache_manager.request_block_ids = MagicMock(wraps=rm.cache_manager.request_block_ids) + t = _Task(prompt_len=256, disable_prefix_caching=True) + result = rm.allocate_resources_for_new_tasks([t]) + assert len(result) == 1 + rm.cache_manager.request_block_ids.assert_not_called() + assert len(t.block_tables) == 4 + assert t.need_block_tables == t.block_tables + assert t.num_cached_tokens == 0 + assert t.cache_info is None + assert t.prompt_token_ids_len == 256 + + def test_allocate_disaggregate(rm_factory): """Disaggregate prefill/decode paths (prefix + no-prefix).""" rm = rm_factory(max_seqs=4, 
enable_prefix=True, dec_token=0, block_size=64, num_free=100) @@ -217,6 +233,11 @@ def test_recycle_free_and_check(rm_factory): t2.block_tables = [0, 1] rm2._recycle_block_tables(t2) assert t2 in rm2.cache_manager._released + t3 = _Task(disable_prefix_caching=True) + t3.block_tables = [2, 3] + rm2._recycle_block_tables(t3) + assert t3 not in rm2.cache_manager._released + assert 2 in rm2.cache_manager._recycled # free + check paths assert rm.free_block_tables(10) == 10 rm.check_and_free_block_tables() diff --git a/tests/v1/test_resource_manager_v1.py b/tests/v1/test_resource_manager_v1.py index d9ab6a59dbc..9b0c485c803 100644 --- a/tests/v1/test_resource_manager_v1.py +++ b/tests/v1/test_resource_manager_v1.py @@ -547,6 +547,23 @@ def test_get_prefix_cached_blocks_with_revert(self): self.assertEqual(request.metrics.gpu_cache_token_num, 4) self.assertEqual(request.metrics.cpu_cache_token_num, 0) + def test_get_prefix_cached_blocks_disable_prefix_caching(self): + manager = _build_manager(enable_prefix_caching=True) + _register_manager_cleanup(self, manager) + request = _make_request(request_id="req-no-cache", prompt_token_ids=list(range(8))) + request.disable_prefix_caching = True + manager.cache_manager = MagicMock() + manager.cache_manager.get_required_block_num.return_value = 2 + + success = manager.get_prefix_cached_blocks(request) + + self.assertTrue(success) + manager.cache_manager.request_match_blocks.assert_not_called() + self.assertEqual(request.cache_info, [0, 2]) + self.assertEqual(request.block_tables, []) + self.assertEqual(request.num_cached_tokens, 0) + self.assertEqual(request.num_computed_tokens, 0) + def test_preallocate_resource_in_p_and_d(self): manager_p = _build_manager(splitwise_role="prefill", enable_prefix_caching=False) _register_manager_cleanup(self, manager_p) @@ -558,6 +575,18 @@ def test_preallocate_resource_in_p_and_d(self): self.assertEqual(request_p.idx, 0) self.assertFalse(manager_p.stop_flags[0]) + manager_p_with_disabled_cache = 
_build_manager(splitwise_role="prefill", enable_prefix_caching=True) + _register_manager_cleanup(self, manager_p_with_disabled_cache) + manager_p_with_disabled_cache.cache_manager = MagicMock() + manager_p_with_disabled_cache.cache_manager.can_allocate_gpu_blocks.return_value = True + manager_p_with_disabled_cache.cache_manager.allocate_gpu_blocks.return_value = [1, 2] + request_p_disabled = _make_request(prompt_token_ids=[1, 2, 3]) + request_p_disabled.disable_prefix_caching = True + self.assertTrue(manager_p_with_disabled_cache.preallocate_resource_in_p(request_p_disabled)) + manager_p_with_disabled_cache.cache_manager.request_match_blocks.assert_not_called() + manager_p_with_disabled_cache.cache_manager.update_cache_blocks.assert_not_called() + self.assertEqual(request_p_disabled.num_computed_tokens, 0) + manager_d = _build_manager(splitwise_role="decode", enable_prefix_caching=False) _register_manager_cleanup(self, manager_d) manager_d.cache_manager = MagicMock() @@ -614,6 +643,21 @@ def test_free_blocks_with_extend_tables(self): self.assertEqual(request.block_tables, []) self.assertEqual(request.extend_block_tables, []) + def test_free_blocks_disable_prefix_caching(self): + manager = _build_manager(enable_prefix_caching=True) + _register_manager_cleanup(self, manager) + manager.cache_manager = MagicMock() + request = _make_request(request_id="req-disabled-free") + request.disable_prefix_caching = True + request.block_tables = [1, 2, 3] + request.num_cached_blocks = 1 + + manager._free_blocks(request) + + manager.cache_manager.release_block_ids.assert_not_called() + manager.cache_manager.recycle_gpu_blocks.assert_called_once_with([1, 2, 3], request.request_id) + self.assertEqual(request.block_tables, []) + def test_finish_requests_updates_state(self): manager = _build_manager() _register_manager_cleanup(self, manager) @@ -621,6 +665,7 @@ def test_finish_requests_updates_state(self): manager.cache_manager.num_gpu_blocks = 8 
manager.cache_manager.gpu_free_block_list = list(range(8)) manager.cache_manager.write_cache_to_storage = MagicMock() + manager.config.cache_config.kvcache_storage_backend = "mock_backend" request = _make_request(request_id="req-finish") request.idx = 0 manager.tasks_list[0] = request @@ -637,6 +682,30 @@ def test_finish_requests_updates_state(self): manager.cache_manager.write_cache_to_storage.assert_called_once_with(request) manager._free_blocks.assert_called_once_with(request) + def test_finish_requests_disable_prefix_caching_skips_storage_write(self): + manager = _build_manager() + _register_manager_cleanup(self, manager) + manager.cache_manager = MagicMock() + manager.cache_manager.num_gpu_blocks = 8 + manager.cache_manager.gpu_free_block_list = list(range(8)) + # Enable a storage backend so finish_requests reaches the per-request check; + # otherwise the write loop is skipped for every request and the asserts pass vacuously. + manager.config.cache_config.kvcache_storage_backend = "mock_backend" + request = _make_request(request_id="req-finish-disabled") + request.disable_prefix_caching = True + request.idx = 0 + manager.tasks_list[0] = request + manager.stop_flags[0] = False + manager.requests[request.request_id] = request + manager.running.append(request) + manager._free_blocks = MagicMock() + + manager.finish_requests([request.request_id]) + + manager.cache_manager.write_cache_to_storage.assert_not_called() + manager.cache_manager.write_cache_to_storage_decode.assert_not_called() + manager._free_blocks.assert_called_once_with(request) + def test_schedule_decode_and_waiting_prefill(self): manager = _build_manager(enable_prefix_caching=False) _register_manager_cleanup(self, manager)