Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions fastdeploy/cache_manager/cache_transfer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,14 @@ def _handle_resume(self):
logger.info("✅ Successfully resumed transfer")
return True

def _handle_update_weights(self):
    """Handle the ``update_weights`` control method for the cache transfer manager.

    When a storage backend is configured (``storage_backend_type`` is not
    ``None``), ``_update_key_prefix`` is invoked; otherwise the step is skipped
    and only an informational message is logged.

    Returns:
        bool: Always ``True``.
    """
    # Guard clause: without a storage backend there is no key prefix to refresh.
    if self.storage_backend_type is None:
        logger.info("💡 Cache storage backend is disabled, skip updating cache key prefix")
        return True

    self._update_key_prefix()
    logger.info("✅ Successfully updated cache key prefix after weight update")
    return True

def _handle_sleep(self):
if self.is_sleeping:
logger.info("💡 Cache transfer manager is already sleeping, no need to sleep again!")
Expand Down Expand Up @@ -1128,6 +1136,7 @@ def control_task(self, task: ControlRequest):
handlers = {
"pause": self._handle_pause,
"resume": self._handle_resume,
"update_weights": self._handle_update_weights,
"sleep": self._handle_sleep,
"wakeup": self._handle_wakeup,
}
Expand Down
31 changes: 27 additions & 4 deletions fastdeploy/engine/common_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1391,10 +1391,16 @@ def _control_pause(self, control_request: ControlRequest):
# pause cache transfer
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
self.llm_logger.info("Start to pause cache transfer.")
pause_transfer_request = ControlRequest(request_id="pause_transfer", method="pause")
pause_transfer_request = ControlRequest(
request_id=f"{control_request.request_id}_pause_transfer", method="pause"
)
self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, pause_transfer_request))
# Wait for cache_transfer responses
asyncio.run(self._wait_for_control_responses("pause_transfer", 60, executors=["cache_transfer"]))
asyncio.run(
self._wait_for_control_responses(
f"{pause_transfer_request.request_id}", 60, executors=["cache_transfer"]
)
)
self.llm_logger.info("Successfully paused cache transfer.")

self.resource_manager.cache_manager.reset()
Expand All @@ -1421,10 +1427,14 @@ def _control_resume(self, control_request: ControlRequest) -> Optional[dict]:
# resume cache transfer
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
self.llm_logger.info("Start to resume cache transfer.")
resume_transfer_request = ControlRequest(request_id="resume_transfer", method="resume")
resume_transfer_request = ControlRequest(
request_id=f"{control_request.request_id}_resume_transfer", method="resume"
)
self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, resume_transfer_request))
# Wait for cache_transfer responses
asyncio.run(self._wait_for_control_responses("resume_transfer", 60, executors=["cache_transfer"]))
asyncio.run(
self._wait_for_control_responses(resume_transfer_request.request_id, 60, executors=["cache_transfer"])
)
self.llm_logger.info("Successfully resumed cache transfer.")

self.llm_logger.info("Successfully resumed request generation.")
Expand Down Expand Up @@ -1479,6 +1489,19 @@ def _control_update_weights(self, control_request: ControlRequest) -> Optional[d
if new_version is not None:
self.cfg.model_config.version = new_version

if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
self.llm_logger.info("Start to update cache-transfer metadata after weight update.")
update_cache_request = ControlRequest(
request_id=f"{control_request.request_id}_update_weights",
method="update_weights",
args=copy.deepcopy(control_request.args),
)
self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, update_cache_request))
asyncio.run(
self._wait_for_control_responses(update_cache_request.request_id, 60, executors=["cache_transfer"])
)
self.llm_logger.info("Successfully updated cache-transfer metadata after weight update.")

return responses

def _parse_tags(self, control_request: ControlRequest):
Expand Down
118 changes: 117 additions & 1 deletion tests/cache_manager/test_cache_transfer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import tempfile
import time
import unittest
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, Mock, patch

import paddle

Expand All @@ -37,6 +37,7 @@ def enable_torch_proxy(scope=None):
import fastdeploy.cache_manager.cache_transfer_manager as cache_transfer_manager
from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask
from fastdeploy.cache_manager.cache_transfer_manager import CacheTransferManager
from fastdeploy.engine.request import ControlRequest


# ==========================
Expand Down Expand Up @@ -121,6 +122,16 @@ def __init__(self, name, array, dtype, suffix, create=False):
patcher_thread.start()
self.addCleanup(patcher_thread.stop)

# --------------------------
# mock FMQ
# --------------------------
patcher_fmq = patch("fastdeploy.cache_manager.cache_transfer_manager.FMQ")
mock_fmq_cls = patcher_fmq.start()
mock_fmq = MagicMock()
mock_fmq.queue.return_value = MagicMock(name="ctrl_output_queue")
mock_fmq_cls.return_value = mock_fmq
self.addCleanup(patcher_fmq.stop)

# --------------------------
# mock _init_cpu_cache 和 _init_gpu_cache
# --------------------------
Expand Down Expand Up @@ -1515,6 +1526,111 @@ def resume_sleep(_):

self.assertFalse(self.manager.is_paused)

def test_init_control_builds_expected_queue_name(self):
    """``_init_control`` must open the producer queue under the expected name and store it."""
    manager = self.manager
    manager.rank = 1
    manager.n_ranks = 4
    manager.local_data_parallel_id = 2
    manager.cache_queue_port = 8899

    expected_queue = MagicMock(name="ctrl_q")
    fake_fmq = MagicMock()
    fake_fmq.queue.return_value = expected_queue

    fmq_patcher = patch("fastdeploy.cache_manager.cache_transfer_manager.FMQ", return_value=fake_fmq)
    with fmq_patcher:
        manager._init_control()

    # NOTE(review): "rank9" presumably comes from rank + n_ranks * local_data_parallel_id
    # (1 + 4 * 2) -- confirm against _init_control.
    fake_fmq.queue.assert_called_once_with("ctrl_c2e_rank9_8899", "producer")
    self.assertIs(manager.ctrl_output_queue, expected_queue)

def test_control_task_success_puts_control_response(self):
    """A successful ``pause`` control task waits on the barrier and emits a 200 response."""
    manager = self.manager

    barrier = MagicMock(wait=Mock())
    output_queue = MagicMock(name="ctrl_q")
    output_queue.put = Mock(return_value="coro")
    pause_handler = MagicMock(return_value=True)

    manager.cache_task_queue.barrier = barrier
    manager.ctrl_output_queue = output_queue
    manager._handle_pause = pause_handler

    request = ControlRequest(request_id="ctrl-1", method="pause")
    # asyncio.run is patched out, so the value returned by ``put`` is never actually executed.
    with patch("fastdeploy.cache_manager.cache_transfer_manager.asyncio.run"):
        manager.control_task(request)

    pause_handler.assert_called_once()
    barrier.wait.assert_called_once()
    output_queue.put.assert_called_once()

    response = output_queue.put.call_args.args[0]
    self.assertEqual(response.request_id, "ctrl-1")
    self.assertEqual(response.error_code, 200)

def test_control_task_unknown_method_returns_400(self):
    """An unrecognised control method must be answered with error code 400."""
    manager = self.manager

    output_queue = MagicMock(name="ctrl_q")
    output_queue.put = Mock(return_value="coro")
    manager.cache_task_queue.barrier = MagicMock(wait=Mock())
    manager.ctrl_output_queue = output_queue

    request = ControlRequest(request_id="ctrl-2", method="unknown")
    with patch("fastdeploy.cache_manager.cache_transfer_manager.asyncio.run"):
        manager.control_task(request)

    # The response object is the first positional argument handed to ``put``.
    response = output_queue.put.call_args.args[0]
    self.assertEqual(response.error_code, 400)
    self.assertIn("Unknown control method", response.error_message)

def test_control_task_exception_returns_500(self):
    """A handler that raises must be reported back as error code 500."""
    manager = self.manager

    output_queue = MagicMock(name="ctrl_q")
    output_queue.put = Mock(return_value="coro")
    manager.cache_task_queue.barrier = MagicMock(wait=Mock())
    manager.ctrl_output_queue = output_queue

    request = ControlRequest(request_id="ctrl-3", method="sleep")
    # Make the sleep handler blow up, and keep asyncio.run from executing anything.
    with patch.object(manager, "_handle_sleep", side_effect=RuntimeError("boom")):
        with patch("fastdeploy.cache_manager.cache_transfer_manager.asyncio.run"):
            manager.control_task(request)

    response = output_queue.put.call_args.args[0]
    self.assertEqual(response.error_code, 500)
    self.assertIn("Failed to execute sleep", response.error_message)

def test_handle_resume_updates_key_prefix_for_storage_backend(self):
    """With a storage backend set, ``_handle_resume`` calls ``resume`` and ``_update_key_prefix`` once each."""
    manager = self.manager
    resume_mock = MagicMock()
    refresh_prefix_mock = MagicMock()

    manager.is_paused = True
    manager.storage_backend_type = "mooncake"
    manager.resume = resume_mock
    manager._update_key_prefix = refresh_prefix_mock

    outcome = manager._handle_resume()

    self.assertTrue(outcome)
    resume_mock.assert_called_once()
    refresh_prefix_mock.assert_called_once()

def test_handle_update_weights_updates_key_prefix_for_storage_backend(self):
    """With a storage backend set, ``_handle_update_weights`` calls ``_update_key_prefix`` exactly once."""
    manager = self.manager
    refresh_prefix_mock = MagicMock()
    manager.storage_backend_type = "mooncake"
    manager._update_key_prefix = refresh_prefix_mock

    outcome = manager._handle_update_weights()

    self.assertTrue(outcome)
    refresh_prefix_mock.assert_called_once()

def test_handle_update_weights_skips_without_storage_backend(self):
    """Without a storage backend, ``_handle_update_weights`` succeeds and leaves the key prefix alone."""
    manager = self.manager
    refresh_prefix_mock = MagicMock()
    manager.storage_backend_type = None
    manager._update_key_prefix = refresh_prefix_mock

    outcome = manager._handle_update_weights()

    self.assertTrue(outcome)
    refresh_prefix_mock.assert_not_called()

def test_handle_sleep_and_wakeup_are_idempotent(self):
    """Sleeping while asleep and waking while awake both succeed without touching the caches."""
    manager = self.manager
    clear_cpu = MagicMock()
    clear_gpu = MagicMock()
    init_cpu = MagicMock()
    init_gpu = MagicMock()

    manager.is_sleeping = True
    manager._clear_cpu_cache = clear_cpu
    manager._clear_gpu_cache = clear_gpu
    manager._init_cpu_cache = init_cpu
    manager._init_gpu_cache = init_gpu

    # Already sleeping: a second sleep request must not clear anything.
    sleep_outcome = manager._handle_sleep()
    self.assertTrue(sleep_outcome)
    clear_cpu.assert_not_called()
    clear_gpu.assert_not_called()

    # Already awake: a wakeup request must not re-initialise anything.
    manager.is_sleeping = False
    wakeup_outcome = manager._handle_wakeup()
    self.assertTrue(wakeup_outcome)
    init_cpu.assert_not_called()
    init_gpu.assert_not_called()

def test_submit_task_decrements_inflight_on_task_error(self):
class DummyPool:
def submit(self, fn, *args):
Expand Down
22 changes: 22 additions & 0 deletions tests/engine/test_common_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def enable_torch_proxy(scope=None):

paddle.compat = _PaddleCompat()

from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.common_engine import EngineService
from fastdeploy.engine.request import (
Expand Down Expand Up @@ -1104,6 +1105,27 @@ def test_control_update_weights_updates_cfg_version(self):
self.assertEqual(eng.cfg.model_config.version, "new-version")
self._detach_finalizer(eng)

def test_control_update_weights_updates_cache_transfer_metadata(self):
    """``_control_update_weights`` must forward an ``update_weights`` control task to cache transfer.

    The engine is configured with one CPU block so the cache-transfer branch is
    taken; the test then checks the queued payload and that the engine awaited
    the cache-transfer response for that same request id.
    """
    eng = self._make_mixed_engine()
    eng.is_paused = True
    eng._pause_cond = threading.Condition()
    eng.cfg.cache_config.num_cpu_blocks = 1

    put_transfer_task = Mock()
    wait_for_responses = AsyncMock(return_value=[{"ok": True}])
    eng._call_worker = Mock(return_value=[{"version": "new-version"}])
    eng.cache_task_queue = Mock(put_transfer_task=put_transfer_task)
    eng._wait_for_control_responses = wait_for_responses

    request = ControlRequest(request_id="ctrl", method="update_weights")
    result = eng._control_update_weights(request)

    self.assertEqual(result, [{"version": "new-version"}])

    # The queued task is a (status, control request) pair passed as the first positional argument.
    payload = put_transfer_task.call_args.args[0]
    status = payload[0]
    forwarded = payload[1]
    self.assertEqual(status, CacheStatus.CTRL)
    self.assertEqual(forwarded.method, "update_weights")
    self.assertIn("update_weights", forwarded.request_id)
    wait_for_responses.assert_awaited_once_with(forwarded.request_id, 60, executors=["cache_transfer"])
    self._detach_finalizer(eng)

def test_control_pause_and_resume_paths(self):
eng = self._make_mixed_engine()
eng.is_paused = False
Expand Down
61 changes: 60 additions & 1 deletion tests/entrypoints/test_engine_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import paddle
import pytest

from fastdeploy.engine.request import ControlRequest
from fastdeploy.engine.request import ControlRequest, ControlResponse
from fastdeploy.entrypoints.engine_client import EngineClient
from fastdeploy.inter_communicator import (
KVCacheStatus,
Expand Down Expand Up @@ -1882,6 +1882,65 @@ def test_valid_parameters_and_control_timeout(minimal_engine_client):
assert resp.error_code == 500


def test_run_control_method_uses_send_pyobj_for_mm_requests(minimal_engine_client):
    """With ``enable_mm`` set, ``run_control_method`` must send via ``send_pyobj`` and not ``send_json``."""
    client = minimal_engine_client

    # Pre-load the response queue with the reply the client is expected to read.
    canned_reply = ({"request_id": "mm-1", "status": 200, "msg": "ok"},)
    reply_queue = asyncio.Queue()
    asyncio.run(reply_queue.put(canned_reply))

    dealer = Mock(write=Mock())
    get_connection = AsyncMock(return_value=(dealer, reply_queue))
    client.enable_mm = True
    client.connection_manager = MagicMock(get_connection=get_connection)

    request = ControlRequest(request_id="mm-1", method="ping")
    with patch("fastdeploy.entrypoints.engine_client.envs.ZMQ_SEND_BATCH_DATA", 0):
        resp = asyncio.run(client.run_control_method(request))

    assert resp.error_code == 200
    zmq_client = client.zmq_client
    zmq_client.send_pyobj.assert_called_once()
    zmq_client.send_json.assert_not_called()


def test_run_control_method_adds_worker_pid_in_batch_mode(minimal_engine_client):
    """With ``ZMQ_SEND_BATCH_DATA`` patched to 1, the JSON payload must carry the client's worker pid."""
    client = minimal_engine_client

    # Pre-load the response queue with the reply the client is expected to read.
    canned_reply = ({"request_id": "batch-1", "status": 200, "msg": "ok"},)
    reply_queue = asyncio.Queue()
    asyncio.run(reply_queue.put(canned_reply))

    # No dealer is supplied in this scenario: the connection is (None, queue).
    get_connection = AsyncMock(return_value=(None, reply_queue))
    client.connection_manager = MagicMock(get_connection=get_connection)

    request = ControlRequest(request_id="batch-1", method="ping")
    with patch("fastdeploy.entrypoints.engine_client.envs.ZMQ_SEND_BATCH_DATA", 1):
        resp = asyncio.run(client.run_control_method(request))

    assert resp.error_code == 200
    payload = client.zmq_client.send_json.call_args.args[0]
    assert payload["zmq_worker_pid"] == client.worker_pid


def test_run_control_method_generic_exception_returns_error(minimal_engine_client):
    """A failure while reading the response queue must surface as a 500 carrying the error text."""
    client = minimal_engine_client

    # The queue's ``get`` raises as soon as it is awaited.
    failing_queue = MagicMock()
    failing_queue.get = AsyncMock(side_effect=RuntimeError("queue failed"))
    dealer = Mock(write=Mock())
    get_connection = AsyncMock(return_value=(dealer, failing_queue))
    client.connection_manager = MagicMock(get_connection=get_connection)

    request = ControlRequest(request_id="r3", method="m")
    with patch("fastdeploy.entrypoints.engine_client.envs.ZMQ_SEND_BATCH_DATA", 0):
        resp = asyncio.run(client.run_control_method(request))

    assert resp.error_code == 500
    assert "queue failed" in resp.error_message


def test_run_control_method_sync_uses_threadsafe_bridge(minimal_engine_client):
    """``run_control_method_sync`` must go through ``asyncio.run_coroutine_threadsafe`` and return its result."""
    client = minimal_engine_client
    req = ControlRequest(request_id="sync-1", method="ping")

    # Stand-in for the future returned by the patched bridge; ``result()`` yields a success response.
    bridged_response = ControlResponse("sync-1", 200, "Success")
    future = Mock(result=Mock(return_value=bridged_response))

    client.run_control_method = AsyncMock(return_value=ControlResponse("sync-1", 200, "Success"))

    bridge_target = "fastdeploy.entrypoints.engine_client.asyncio.run_coroutine_threadsafe"
    with patch(bridge_target, return_value=future) as mock_run:
        resp = client.run_control_method_sync(req, Mock())

    assert resp.error_code == 200
    mock_run.assert_called_once()
    # Close the object handed to the patched bridge -- presumably the un-awaited coroutine,
    # closed here to avoid a "never awaited" warning (confirm against run_control_method_sync).
    first_positional = mock_run.call_args.args[0]
    first_positional.close()


def test_rearrange_and_redundant_branch_matrix(minimal_engine_client):
cfg = create_mock_fd_config(enable_eplb=True)
cfg.parallel_config.tensor_parallel_rank = 0
Expand Down
Loading
Loading