diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index 08c8dea003a..4a188edd161 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -620,6 +620,8 @@ def __init__( dict() for _ in range(512) ] # {'layer_id': {'prefilled_layer_idx': xx, 'prefilled_block_num': xx}} self.idx_cache_task_dict = {} # {'slot_idx': cache_info_dict} + self.pending_layer0_signals = {} + self.pending_layer0_signal_lock = threading.Lock() self.cache_prefilled_engine_ids_queue = ( queue.Queue() ) # [(slot_idx1, prefilled_token_num1), (slot_idx2, prefilled_token_num2)] @@ -663,7 +665,28 @@ def _add_cache_task_thread(self): current_info["status"] = "init" logger.info(f"Get cache info and finish add cache task: {current_info}") self.cache_info[info["request_id"]] = current_info - self.idx_cache_task_dict[current_info["current_id"]] = current_info + current_id = current_info["current_id"] + with self.engine_cache_task_thread_lock: + self.idx_cache_task_dict[current_id] = current_info + with self.pending_layer0_signal_lock: + recovered_signal = self.pending_layer0_signals.pop(current_id, None) + if recovered_signal is not None: + _, prefilled_token_num = recovered_signal + if prefilled_token_num <= current_info["need_prefill_tokens"]: + recovered_signal_batch = [recovered_signal] + logger.info( + "cache_task_register_recover_layer0_signal: " + f"current_id: {current_id}, " + f"recovered_signal_batch: {recovered_signal_batch}" + ) + self.cache_prefilled_engine_ids_queue.put(recovered_signal_batch) + else: + logger.info( + "cache_task_register_drop_layer0_signal: " + f"current_id: {current_id}, " + f"recovered_signal: {recovered_signal}, " + f"need_prefill_tokens: {current_info['need_prefill_tokens']}" + ) else: logger.info(f"Get cache info: {info}") self.cache_info[info["request_id"]] = info @@ -842,9 +865,12 @@ def prefill_layerwise_send_cache_thread(self): logger.info( f"Put successful cache writing task in engine worker queue, req_id: {task['request_id']}, status: {task['status']}" ) - self.engine_cache_tasks[task["current_id"]] = dict() + current_id = task["current_id"] + self.engine_cache_tasks[current_id] = dict() del self.cache_info[task["request_id"]] - del self.idx_cache_task_dict[task["current_id"]] + del self.idx_cache_task_dict[current_id] + with self.pending_layer0_signal_lock: + self.pending_layer0_signals.pop(current_id, None) break except Exception as e: logger.error(f"prefill layerwise send cache thread has exception: {e} {traceback.format_exc()!s}") @@ -856,32 +882,42 @@ def consume_signals(self): while True: try: get_output_kv_signal(kv_signal_data, self.rank_id, 1) # wait_flag - if not self.cache_info: - time.sleep(0.01) - continue - tasks_count = kv_signal_data[0] + has_cache_info = bool(self.cache_info) + tasks_count = kv_signal_data[0].item() if tasks_count == -1: continue + if not has_cache_info: + logger.debug("consume_signals get kv signal before cache info is ready") layer_id = kv_signal_data[1].item() if layer_id == self.num_layers - 1: logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id} self.rank_id {self.rank_id}") - batch_engine_signals = [] + ready_engine_signals = [] + pending_engine_signals = [] # format for signal to put in cache_prefilled_engine_ids_queue: [(engine_idx1, prefilled_token_num1), (engine_idx2, prefilled_token_num2)] with self.engine_cache_task_thread_lock: for bi in range(tasks_count): engine_idx = kv_signal_data[3 * bi + 2].item() chuck_token_offset = kv_signal_data[3 * bi + 3].item() current_seq_len = kv_signal_data[3 * bi + 4].item() + prefilled_token_num = chuck_token_offset + current_seq_len self.engine_cache_tasks[engine_idx]["prefilled_layer_idx"] = layer_id - self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = ( - chuck_token_offset + current_seq_len - ) - batch_engine_signals.append((engine_idx, chuck_token_offset + current_seq_len)) - if layer_id == 0: - logger.info( - f"Put batch_engine_signals {batch_engine_signals} into cache_prefilled_engine_ids_queue" - ) - self.cache_prefilled_engine_ids_queue.put(batch_engine_signals) + self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = prefilled_token_num + if layer_id == 0: + if engine_idx in self.idx_cache_task_dict: + ready_engine_signals.append((engine_idx, prefilled_token_num)) + else: + pending_engine_signals.append((engine_idx, prefilled_token_num)) + if pending_engine_signals: + with self.pending_layer0_signal_lock: + for engine_idx, prefilled_token_num in pending_engine_signals: + self.pending_layer0_signals[engine_idx] = (engine_idx, prefilled_token_num) + if pending_engine_signals: + logger.debug(f"cache_task_pending_layer0_signal: {pending_engine_signals}") + if ready_engine_signals: + logger.info( + f"Put batch_engine_signals {ready_engine_signals} into cache_prefilled_engine_ids_queue" + ) + self.cache_prefilled_engine_ids_queue.put(ready_engine_signals) except Exception as e: logger.error(f"Consume signals get exception: {e}") diff --git a/tests/cache_manager/test_cache_messager.py b/tests/cache_manager/test_cache_messager.py index 3e415ebe9c8..07ff5054f2a 100644 --- a/tests/cache_manager/test_cache_messager.py +++ b/tests/cache_manager/test_cache_messager.py @@ -124,6 +124,14 @@ def error(self, msg): self.messages.append(("error", msg)) +class _QueueRecorder: + def __init__(self): + self.items = [] + + def put(self, item): + self.items.append(item) + + class _DummySignalValue: def __init__(self, sequence): self.sequence = list(sequence) @@ -390,6 +398,111 @@ def test_cache_messager_v1_add_cache_task_thread(monkeypatch): assert messager.cache_info["req-2"]["status"] == "init" +def test_cache_messager_v1_recovers_pending_layer0_signal(monkeypatch): + dummy_queue = _DummyEngineWorkerQueue( + cache_info_sequence=[ + [ + { + "request_id": "req-pending", + "src_block_ids": [0, 1], + "dest_block_ids": [2], + "current_id": 3, + "need_prefill_tokens": 128, + "transfer_protocol": "rdma", + } + ] + ] + ) + monkeypatch.setattr(cache_messager, "EngineWorkerQueue", lambda *args, **kwargs: dummy_queue) + monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager) + monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False) + + gpu_cache_kvs = _build_cache_kvs(dtype="float16", include_value_cache=True, num_layers=1) + messager = cache_messager.CacheMessagerV1( + splitwise_role="mixed", + transfer_protocol="rdma", + pod_ip="0.0.0.0", + engine_worker_queue_port=9000, + local_data_parallel_id=0, + gpu_cache_kvs=gpu_cache_kvs, + rank=0, + nranks=1, + num_layers=1, + gpu_id=0, + block_size=64, + rdma_port="2222", + ) + messager.cache_prefilled_engine_ids_queue = _QueueRecorder() + messager.cache_info["req-pending"] = { + "request_id": "req-pending", + "src_block_ids": [0, 1], + "dest_block_ids": [2], + "current_id": 3, + "need_prefill_tokens": 128, + "transfer_protocol": "rdma", + } + messager.pending_layer0_signals[3] = (3, 64) + messager.pending_layer0_signals[4] = (4, 64) + + with pytest.raises(SystemExit): + messager._add_cache_task_thread() + + assert messager.pending_layer0_signals == {4: (4, 64)} + assert messager.cache_prefilled_engine_ids_queue.items == [[(3, 64)]] + + +def test_cache_messager_v1_drops_invalid_pending_layer0_signal(monkeypatch): + dummy_queue = _DummyEngineWorkerQueue( + cache_info_sequence=[ + [ + { + "request_id": "req-pending", + "src_block_ids": [0, 1], + "dest_block_ids": [2], + "current_id": 3, + "need_prefill_tokens": 128, + "transfer_protocol": "rdma", + } + ] + ] + ) + monkeypatch.setattr(cache_messager, "EngineWorkerQueue", lambda *args, **kwargs: dummy_queue) + monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager) + monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False) + + gpu_cache_kvs = _build_cache_kvs(dtype="float16", include_value_cache=True, num_layers=1) + messager = cache_messager.CacheMessagerV1( + splitwise_role="mixed", + transfer_protocol="rdma", + pod_ip="0.0.0.0", + engine_worker_queue_port=9000, + local_data_parallel_id=0, + gpu_cache_kvs=gpu_cache_kvs, + rank=0, + nranks=1, + num_layers=1, + gpu_id=0, + block_size=64, + rdma_port="2222", + ) + messager.cache_prefilled_engine_ids_queue = _QueueRecorder() + messager.cache_info["req-pending"] = { + "request_id": "req-pending", + "src_block_ids": [0, 1], + "dest_block_ids": [2], + "current_id": 3, + "need_prefill_tokens": 128, + "transfer_protocol": "rdma", + } + messager.pending_layer0_signals[3] = (3, 256) + + with pytest.raises(SystemExit): + messager._add_cache_task_thread() + + assert messager.pending_layer0_signals == {} + assert messager.cache_prefilled_engine_ids_queue.items == [] + + def test_cache_messager_v1_prefill_layerwise_send_cache_thread(monkeypatch): class _OneShotQueue: def __init__(self): @@ -435,10 +548,12 @@ def get(self): } messager.engine_cache_tasks[0] = {"prefilled_layer_idx": 1, "prefilled_token_num": 64} messager.cache_info["req-3"] = messager.idx_cache_task_dict[0] + messager.pending_layer0_signals = {0: (0, 64), 1: (1, 64)} with pytest.raises(SystemExit): messager.prefill_layerwise_send_cache_thread() assert dummy_queue.finished_req_payloads assert dummy_queue.finished_req_payloads[0][0][0] == "req-3" + assert messager.pending_layer0_signals == {1: (1, 64)} def test_cache_messager_v1_handle_connect_task(monkeypatch): @@ -562,13 +677,6 @@ def test_cache_messager_v1_consume_signals(monkeypatch): monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager) monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False) - class _QueueRecorder: - def __init__(self): - self.items = [] - - def put(self, item): - self.items.append(item) - counter = {"calls": 0} def _fake_get_output_kv_signal(kv_signal_data, rank_id, wait_flag): @@ -600,12 +708,57 @@ def _fake_get_output_kv_signal(kv_signal_data, rank_id, wait_flag): rdma_port="2222", ) messager.cache_info["req-4"] = {"request_id": "req-4"} + messager.idx_cache_task_dict[2] = {"request_id": "req-4", "current_id": 2} messager.cache_prefilled_engine_ids_queue = _QueueRecorder() with pytest.raises(SystemExit): messager.consume_signals() assert messager.cache_prefilled_engine_ids_queue.items == [[(2, 9)]] +def test_cache_messager_v1_consume_signals_buffers_early_layer0(monkeypatch): + monkeypatch.setattr(cache_messager, "EngineWorkerQueue", _DummyEngineWorkerQueue) + monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager) + monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False) + + signals = [(5, 7, 9), (5, 17, 19)] + + def _fake_get_output_kv_signal(kv_signal_data, rank_id, wait_flag): + if not signals: + raise SystemExit + engine_idx, chuck_token_offset, current_seq_len = signals.pop(0) + data = np.full(kv_signal_data.shape, -1, dtype="int32") + data[0] = 1 + data[1] = 0 + data[2] = engine_idx + data[3] = chuck_token_offset + data[4] = current_seq_len + kv_signal_data.set_value(data) + + monkeypatch.setattr(cache_messager, "get_output_kv_signal", _fake_get_output_kv_signal) + gpu_cache_kvs = _build_cache_kvs(dtype="float16", include_value_cache=False, num_layers=1) + messager = cache_messager.CacheMessagerV1( + splitwise_role="mixed", + transfer_protocol="rdma", + pod_ip="0.0.0.0", + engine_worker_queue_port=9000, + local_data_parallel_id=0, + gpu_cache_kvs=gpu_cache_kvs, + rank=0, + nranks=1, + num_layers=1, + gpu_id=0, + block_size=64, + rdma_port="2222", + ) + messager.cache_prefilled_engine_ids_queue = _QueueRecorder() + + with pytest.raises(SystemExit): + messager.consume_signals() + + assert messager.pending_layer0_signals == {5: (5, 36)} + assert messager.cache_prefilled_engine_ids_queue.items == [] + + def test_main_initializes_cache_and_exits(monkeypatch): monkeypatch.setattr(cache_messager, "set_device", lambda device: None) monkeypatch.setattr(cache_messager, "set_data_ipc", lambda tensor, name: None)