diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index b264f03b753..26a08e8eec6 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -1098,6 +1098,14 @@ def _handle_resume(self): logger.info("✅ Successfully resumed transfer") return True + def _handle_update_weights(self): + if self.storage_backend_type is not None: + self._update_key_prefix() + logger.info("✅ Successfully updated cache key prefix after weight update") + else: + logger.info("💡 Cache storage backend is disabled, skip updating cache key prefix") + return True + def _handle_sleep(self): if self.is_sleeping: logger.info("💡 Cache transfer manager is already sleeping, no need to sleep again!") @@ -1128,6 +1136,7 @@ def control_task(self, task: ControlRequest): handlers = { "pause": self._handle_pause, "resume": self._handle_resume, + "update_weights": self._handle_update_weights, "sleep": self._handle_sleep, "wakeup": self._handle_wakeup, } diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index fdc160735b7..26357eeeee2 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -1411,10 +1411,16 @@ def _control_pause(self, control_request: ControlRequest): # pause cache transfer if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: self.llm_logger.info("Start to pause cache transfer.") - pause_transfer_request = ControlRequest(request_id="pause_transfer", method="pause") + pause_transfer_request = ControlRequest( + request_id=f"{control_request.request_id}_pause_transfer", method="pause" + ) self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, pause_transfer_request)) # Wait for cache_transfer responses - asyncio.run(self._wait_for_control_responses("pause_transfer", 60, executors=["cache_transfer"])) + asyncio.run( + self._wait_for_control_responses( + f"{pause_transfer_request.request_id}", 60, executors=["cache_transfer"] + ) + ) self.llm_logger.info("Successfully paused cache transfer.") self.resource_manager.cache_manager.reset() @@ -1441,10 +1447,14 @@ def _control_resume(self, control_request: ControlRequest) -> Optional[dict]: # resume cache transfer if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: self.llm_logger.info("Start to resume cache transfer.") - resume_transfer_request = ControlRequest(request_id="resume_transfer", method="resume") + resume_transfer_request = ControlRequest( + request_id=f"{control_request.request_id}_resume_transfer", method="resume" + ) self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, resume_transfer_request)) # Wait for cache_transfer responses - asyncio.run(self._wait_for_control_responses("resume_transfer", 60, executors=["cache_transfer"])) + asyncio.run( + self._wait_for_control_responses(resume_transfer_request.request_id, 60, executors=["cache_transfer"]) + ) self.llm_logger.info("Successfully resumed cache transfer.") self.llm_logger.info("Successfully resumed request generation.") @@ -1499,6 +1509,19 @@ def _control_update_weights(self, control_request: ControlRequest) -> Optional[d if new_version is not None: self.cfg.model_config.version = new_version + if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: + self.llm_logger.info("Start to update cache-transfer metadata after weight update.") + update_cache_request = ControlRequest( + request_id=f"{control_request.request_id}_update_weights", + method="update_weights", + args=copy.deepcopy(control_request.args), + ) + self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, update_cache_request)) + asyncio.run( + self._wait_for_control_responses(update_cache_request.request_id, 60, executors=["cache_transfer"]) + ) + self.llm_logger.info("Successfully updated cache-transfer metadata after weight update.") + return responses def _control_abort_requests(self, control_req: ControlRequest): diff --git a/tests/cache_manager/test_cache_transfer_manager.py b/tests/cache_manager/test_cache_transfer_manager.py index 191d1ad36e0..76419eba8cd 100644 --- a/tests/cache_manager/test_cache_transfer_manager.py +++ b/tests/cache_manager/test_cache_transfer_manager.py @@ -17,7 +17,7 @@ import tempfile import time import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, Mock, patch import paddle @@ -37,6 +37,7 @@ def enable_torch_proxy(scope=None): import fastdeploy.cache_manager.cache_transfer_manager as cache_transfer_manager from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask from fastdeploy.cache_manager.cache_transfer_manager import CacheTransferManager +from fastdeploy.engine.request import ControlRequest # ========================== @@ -121,6 +122,16 @@ def __init__(self, name, array, dtype, suffix, create=False): patcher_thread.start() self.addCleanup(patcher_thread.stop) + # -------------------------- + # mock FMQ + # -------------------------- + patcher_fmq = patch("fastdeploy.cache_manager.cache_transfer_manager.FMQ") + mock_fmq_cls = patcher_fmq.start() + mock_fmq = MagicMock() + mock_fmq.queue.return_value = MagicMock(name="ctrl_output_queue") + mock_fmq_cls.return_value = mock_fmq + self.addCleanup(patcher_fmq.stop) + # -------------------------- # mock _init_cpu_cache 和 _init_gpu_cache # -------------------------- @@ -1515,6 +1526,111 @@ def resume_sleep(_): self.assertFalse(self.manager.is_paused) + def test_init_control_builds_expected_queue_name(self): + self.manager.rank = 1 + self.manager.n_ranks = 4 + self.manager.local_data_parallel_id = 2 + self.manager.cache_queue_port = 8899 + + queue = MagicMock(name="ctrl_q") + fmq = MagicMock() + fmq.queue.return_value = queue + + with patch("fastdeploy.cache_manager.cache_transfer_manager.FMQ", return_value=fmq): + self.manager._init_control() + + fmq.queue.assert_called_once_with("ctrl_c2e_rank9_8899", "producer") + self.assertIs(self.manager.ctrl_output_queue, queue) + + def test_control_task_success_puts_control_response(self): + self.manager.cache_task_queue.barrier = MagicMock(wait=Mock()) + self.manager.ctrl_output_queue = MagicMock(name="ctrl_q") + self.manager.ctrl_output_queue.put = Mock(return_value="coro") + self.manager._handle_pause = MagicMock(return_value=True) + + with patch("fastdeploy.cache_manager.cache_transfer_manager.asyncio.run"): + self.manager.control_task(ControlRequest(request_id="ctrl-1", method="pause")) + + self.manager._handle_pause.assert_called_once() + self.manager.cache_task_queue.barrier.wait.assert_called_once() + self.manager.ctrl_output_queue.put.assert_called_once() + response = self.manager.ctrl_output_queue.put.call_args.args[0] + self.assertEqual(response.request_id, "ctrl-1") + self.assertEqual(response.error_code, 200) + + def test_control_task_unknown_method_returns_400(self): + self.manager.cache_task_queue.barrier = MagicMock(wait=Mock()) + self.manager.ctrl_output_queue = MagicMock(name="ctrl_q") + self.manager.ctrl_output_queue.put = Mock(return_value="coro") + + with patch("fastdeploy.cache_manager.cache_transfer_manager.asyncio.run"): + self.manager.control_task(ControlRequest(request_id="ctrl-2", method="unknown")) + + response = self.manager.ctrl_output_queue.put.call_args.args[0] + self.assertEqual(response.error_code, 400) + self.assertIn("Unknown control method", response.error_message) + + def test_control_task_exception_returns_500(self): + self.manager.cache_task_queue.barrier = MagicMock(wait=Mock()) + self.manager.ctrl_output_queue = MagicMock(name="ctrl_q") + self.manager.ctrl_output_queue.put = Mock(return_value="coro") + + with ( + patch.object(self.manager, "_handle_sleep", side_effect=RuntimeError("boom")), + patch("fastdeploy.cache_manager.cache_transfer_manager.asyncio.run"), + ): + self.manager.control_task(ControlRequest(request_id="ctrl-3", method="sleep")) + + response = self.manager.ctrl_output_queue.put.call_args.args[0] + self.assertEqual(response.error_code, 500) + self.assertIn("Failed to execute sleep", response.error_message) + + def test_handle_resume_updates_key_prefix_for_storage_backend(self): + self.manager.is_paused = True + self.manager.storage_backend_type = "mooncake" + self.manager.resume = MagicMock() + self.manager._update_key_prefix = MagicMock() + + result = self.manager._handle_resume() + + self.assertTrue(result) + self.manager.resume.assert_called_once() + self.manager._update_key_prefix.assert_called_once() + + def test_handle_update_weights_updates_key_prefix_for_storage_backend(self): + self.manager.storage_backend_type = "mooncake" + self.manager._update_key_prefix = MagicMock() + + result = self.manager._handle_update_weights() + + self.assertTrue(result) + self.manager._update_key_prefix.assert_called_once() + + def test_handle_update_weights_skips_without_storage_backend(self): + self.manager.storage_backend_type = None + self.manager._update_key_prefix = MagicMock() + + result = self.manager._handle_update_weights() + + self.assertTrue(result) + self.manager._update_key_prefix.assert_not_called() + + def test_handle_sleep_and_wakeup_are_idempotent(self): + self.manager.is_sleeping = True + self.manager._clear_cpu_cache = MagicMock() + self.manager._clear_gpu_cache = MagicMock() + self.manager._init_cpu_cache = MagicMock() + self.manager._init_gpu_cache = MagicMock() + + self.assertTrue(self.manager._handle_sleep()) + self.manager._clear_cpu_cache.assert_not_called() + self.manager._clear_gpu_cache.assert_not_called() + + self.manager.is_sleeping = False + self.assertTrue(self.manager._handle_wakeup()) + self.manager._init_cpu_cache.assert_not_called() + self.manager._init_gpu_cache.assert_not_called() + def test_submit_task_decrements_inflight_on_task_error(self): class DummyPool: def submit(self, fn, *args): diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py index 44f784cff33..5b70c42fdbd 100644 --- a/tests/engine/test_common_engine.py +++ b/tests/engine/test_common_engine.py @@ -38,6 +38,7 @@ def enable_torch_proxy(scope=None): paddle.compat = _PaddleCompat() +from fastdeploy.cache_manager.cache_data import CacheStatus from fastdeploy.engine.args_utils import EngineArgs from fastdeploy.engine.common_engine import EngineService from fastdeploy.engine.request import ( @@ -1112,6 +1113,27 @@ def test_control_update_weights_updates_cfg_version(self): self.assertEqual(eng.cfg.model_config.version, "new-version") self._detach_finalizer(eng) + def test_control_update_weights_updates_cache_transfer_metadata(self): + eng = self._make_mixed_engine() + eng.is_paused = True + eng._pause_cond = threading.Condition() + eng.cfg.cache_config.num_cpu_blocks = 1 + eng._call_worker = Mock(return_value=[{"version": "new-version"}]) + eng.cache_task_queue = Mock(put_transfer_task=Mock()) + eng._wait_for_control_responses = AsyncMock(return_value=[{"ok": True}]) + + result = eng._control_update_weights(ControlRequest(request_id="ctrl", method="update_weights")) + + self.assertEqual(result, [{"version": "new-version"}]) + payload = eng.cache_task_queue.put_transfer_task.call_args.args[0] + self.assertEqual(payload[0], CacheStatus.CTRL) + self.assertEqual(payload[1].method, "update_weights") + self.assertIn("update_weights", payload[1].request_id) + eng._wait_for_control_responses.assert_awaited_once_with( + payload[1].request_id, 60, executors=["cache_transfer"] + ) + self._detach_finalizer(eng) + def test_control_pause_and_resume_paths(self): eng = self._make_mixed_engine() eng.is_paused = False diff --git a/tests/entrypoints/test_engine_client.py b/tests/entrypoints/test_engine_client.py index fa7bb4509da..0ed8fbdc033 100644 --- a/tests/entrypoints/test_engine_client.py +++ b/tests/entrypoints/test_engine_client.py @@ -25,7 +25,7 @@ import paddle import pytest -from fastdeploy.engine.request import ControlRequest +from fastdeploy.engine.request import ControlRequest, ControlResponse from fastdeploy.entrypoints.engine_client import EngineClient from fastdeploy.inter_communicator import ( KVCacheStatus, @@ -1882,6 +1882,65 @@ def test_valid_parameters_and_control_timeout(minimal_engine_client): assert resp.error_code == 500 +def test_run_control_method_uses_send_pyobj_for_mm_requests(minimal_engine_client): + queue = asyncio.Queue() + asyncio.run(queue.put(({"request_id": "mm-1", "status": 200, "msg": "ok"},))) + dealer = Mock(write=Mock()) + minimal_engine_client.enable_mm = True + minimal_engine_client.connection_manager = MagicMock(get_connection=AsyncMock(return_value=(dealer, queue))) + + with patch("fastdeploy.entrypoints.engine_client.envs.ZMQ_SEND_BATCH_DATA", 0): + resp = asyncio.run(minimal_engine_client.run_control_method(ControlRequest(request_id="mm-1", method="ping"))) + + assert resp.error_code == 200 + minimal_engine_client.zmq_client.send_pyobj.assert_called_once() + minimal_engine_client.zmq_client.send_json.assert_not_called() + + +def test_run_control_method_adds_worker_pid_in_batch_mode(minimal_engine_client): + queue = asyncio.Queue() + asyncio.run(queue.put(({"request_id": "batch-1", "status": 200, "msg": "ok"},))) + minimal_engine_client.connection_manager = MagicMock(get_connection=AsyncMock(return_value=(None, queue))) + + with patch("fastdeploy.entrypoints.engine_client.envs.ZMQ_SEND_BATCH_DATA", 1): + resp = asyncio.run( + minimal_engine_client.run_control_method(ControlRequest(request_id="batch-1", method="ping")) + ) + + assert resp.error_code == 200 + payload = minimal_engine_client.zmq_client.send_json.call_args.args[0] + assert payload["zmq_worker_pid"] == minimal_engine_client.worker_pid + + +def test_run_control_method_generic_exception_returns_error(minimal_engine_client): + queue = MagicMock() + queue.get = AsyncMock(side_effect=RuntimeError("queue failed")) + dealer = Mock(write=Mock()) + minimal_engine_client.connection_manager = MagicMock(get_connection=AsyncMock(return_value=(dealer, queue))) + + with patch("fastdeploy.entrypoints.engine_client.envs.ZMQ_SEND_BATCH_DATA", 0): + resp = asyncio.run(minimal_engine_client.run_control_method(ControlRequest(request_id="r3", method="m"))) + + assert resp.error_code == 500 + assert "queue failed" in resp.error_message + + +def test_run_control_method_sync_uses_threadsafe_bridge(minimal_engine_client): + req = ControlRequest(request_id="sync-1", method="ping") + future = Mock(result=Mock(return_value=ControlResponse("sync-1", 200, "Success"))) + + minimal_engine_client.run_control_method = AsyncMock(return_value=ControlResponse("sync-1", 200, "Success")) + + with patch( + "fastdeploy.entrypoints.engine_client.asyncio.run_coroutine_threadsafe", return_value=future + ) as mock_run: + resp = minimal_engine_client.run_control_method_sync(req, Mock()) + + assert resp.error_code == 200 + mock_run.assert_called_once() + mock_run.call_args.args[0].close() + + def test_rearrange_and_redundant_branch_matrix(minimal_engine_client): cfg = create_mock_fd_config(enable_eplb=True) cfg.parallel_config.tensor_parallel_rank = 0 diff --git a/tests/worker/test_gpu_model_runner.py b/tests/worker/test_gpu_model_runner.py index d135cf11622..3a02475b5ae 100644 --- a/tests/worker/test_gpu_model_runner.py +++ b/tests/worker/test_gpu_model_runner.py @@ -14,12 +14,13 @@ import unittest from dataclasses import dataclass -from unittest.mock import Mock +from unittest.mock import Mock, patch import numpy as np import paddle from fastdeploy.engine.request import ImagePosition +from fastdeploy.spec_decode import SpecMethod from fastdeploy.worker.gpu_model_runner import GPUModelRunner from fastdeploy.worker.input_batch import InputBatch @@ -476,5 +477,119 @@ def test_process_mm_features_no_encoder_cache(self): ) +class TestSleepWakeupBehavior(unittest.TestCase): + def _make_runner(self): + runner = GPUModelRunner.__new__(GPUModelRunner) + runner.is_weight_sleeping = False + runner.is_kvcache_sleeping = False + runner.use_cudagraph = False + runner.spec_method = None + runner.local_rank = 0 + runner.device_id = 1 + runner.num_gpu_blocks = 8 + runner.model = Mock(clear_grpah_opt_backend=Mock()) + runner.clear_cache = Mock() + runner.initialize_kv_cache = Mock() + runner.capture_model = Mock() + runner.share_inputs = Mock(reset_share_inputs=Mock()) + runner.dynamic_weight_manager = Mock( + clear_deepep_buffer=Mock(), + clear_model_weight=Mock(), + clear_communication_group=Mock(), + restart_communication_group=Mock(), + recreate_deepep_buffer=Mock(), + reload_model_weights=Mock(), + ) + runner.fd_config = Mock() + runner.fd_config.parallel_config = Mock( + enable_expert_parallel=False, + shutdown_comm_group_if_worker_idle=False, + ) + runner.proposer = Mock( + clear_mtp_cache=Mock(), + initialize_kv_cache=Mock(), + model_inputs=Mock(reset_model_inputs=Mock()), + ) + return runner + + @patch("fastdeploy.worker.gpu_model_runner.print_gpu_memory_use") + @patch("paddle.device.cuda.empty_cache") + def test_sleep_offloads_weight_and_cache(self, mock_empty_cache, mock_print_memory): + runner = self._make_runner() + runner.use_cudagraph = True + runner.spec_method = SpecMethod.MTP + runner.fd_config.parallel_config.enable_expert_parallel = True + runner.fd_config.parallel_config.shutdown_comm_group_if_worker_idle = True + + runner.sleep("weight,kv_cache") + + runner.model.clear_grpah_opt_backend.assert_called_once() + runner.dynamic_weight_manager.clear_deepep_buffer.assert_called_once() + runner.dynamic_weight_manager.clear_model_weight.assert_called_once() + runner.dynamic_weight_manager.clear_communication_group.assert_called_once() + runner.proposer.clear_mtp_cache.assert_called_once() + runner.clear_cache.assert_called_once() + self.assertTrue(runner.is_weight_sleeping) + self.assertTrue(runner.is_kvcache_sleeping) + mock_empty_cache.assert_called_once() + mock_print_memory.assert_called_once() + + @patch("fastdeploy.worker.gpu_model_runner.print_gpu_memory_use") + @patch("paddle.device.cuda.empty_cache") + def test_sleep_weight_is_idempotent(self, mock_empty_cache, mock_print_memory): + runner = self._make_runner() + runner.is_weight_sleeping = True + + runner.sleep("weight") + + runner.dynamic_weight_manager.clear_model_weight.assert_not_called() + runner.clear_cache.assert_not_called() + mock_empty_cache.assert_not_called() + mock_print_memory.assert_not_called() + + def test_wakeup_rejects_weight_only_when_cudagraph_requires_kvcache(self): + runner = self._make_runner() + runner.use_cudagraph = True + runner.is_kvcache_sleeping = True + + with self.assertRaises(RuntimeError): + runner.wakeup("weight") + + @patch("fastdeploy.worker.gpu_model_runner.print_gpu_memory_use") + def test_wakeup_restores_weight_and_cache(self, mock_print_memory): + runner = self._make_runner() + runner.use_cudagraph = True + runner.spec_method = SpecMethod.MTP + runner.is_weight_sleeping = True + runner.is_kvcache_sleeping = True + runner.fd_config.parallel_config.enable_expert_parallel = True + runner.fd_config.parallel_config.shutdown_comm_group_if_worker_idle = True + + runner.wakeup("weight,kv_cache") + + runner.proposer.model_inputs.reset_model_inputs.assert_called_once() + runner.share_inputs.reset_share_inputs.assert_called_once() + runner.proposer.initialize_kv_cache.assert_called_once_with(main_model_num_blocks=runner.num_gpu_blocks) + runner.initialize_kv_cache.assert_called_once() + runner.dynamic_weight_manager.restart_communication_group.assert_called_once() + runner.dynamic_weight_manager.recreate_deepep_buffer.assert_called_once() + runner.dynamic_weight_manager.reload_model_weights.assert_called_once() + runner.capture_model.assert_called_once() + self.assertFalse(runner.is_weight_sleeping) + self.assertFalse(runner.is_kvcache_sleeping) + mock_print_memory.assert_called_once() + + @patch("fastdeploy.worker.gpu_model_runner.print_gpu_memory_use") + def test_wakeup_kvcache_is_idempotent(self, mock_print_memory): + runner = self._make_runner() + runner.is_kvcache_sleeping = False + + runner.wakeup("kv_cache") + + runner.initialize_kv_cache.assert_not_called() + runner.dynamic_weight_manager.reload_model_weights.assert_not_called() + mock_print_memory.assert_not_called() + + if __name__ == "__main__": unittest.main() diff --git a/tests/worker/test_gpu_worker.py b/tests/worker/test_gpu_worker.py new file mode 100644 index 00000000000..14f8dc1c02c --- /dev/null +++ b/tests/worker/test_gpu_worker.py @@ -0,0 +1,99 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import Mock + +from fastdeploy.config import FDConfig +from fastdeploy.worker.gpu_worker import GpuWorker + + +class TestGpuWorkerSleepWakeup(unittest.TestCase): + """Test cases for GpuWorker sleep and wakeup methods - Coverage for lines 201, 205""" + + def setUp(self): + """Set up test fixtures""" + self.mock_fd_config = Mock(spec=FDConfig) + self.mock_fd_config.parallel_config = Mock() + self.mock_fd_config.parallel_config.tensor_parallel_size = 1 + + def test_sleep_delegates_to_model_runner(self): + """Test sleep method delegates to model_runner (line 201)""" + worker = GpuWorker.__new__(GpuWorker) + worker.model_runner = Mock() + + # Call sleep + worker.sleep(tags="weight") + + # Verify model_runner.sleep was called + worker.model_runner.sleep.assert_called_once_with(tags="weight") + + def test_sleep_with_multiple_tags(self): + """Test sleep with multiple tags""" + worker = GpuWorker.__new__(GpuWorker) + worker.model_runner = Mock() + + # Call sleep with multiple tags + worker.sleep(tags="weight,kv_cache") + + # Verify model_runner.sleep was called with correct tags + worker.model_runner.sleep.assert_called_once_with(tags="weight,kv_cache") + + def test_sleep_with_kwargs(self): + """Test sleep passes kwargs to model_runner""" + worker = GpuWorker.__new__(GpuWorker) + worker.model_runner = Mock() + + # Call sleep with kwargs + worker.sleep(tags="weight", force=True, timeout=100) + + # Verify model_runner.sleep was called with kwargs + worker.model_runner.sleep.assert_called_once_with(tags="weight", force=True, timeout=100) + + def test_wakeup_delegates_to_model_runner(self): + """Test wakeup method delegates to model_runner (line 205)""" + worker = GpuWorker.__new__(GpuWorker) + worker.model_runner = Mock() + + # Call wakeup + worker.wakeup(tags="weight") + + # Verify model_runner.wakeup was called + worker.model_runner.wakeup.assert_called_once_with(tags="weight") + + def test_wakeup_with_multiple_tags(self): + """Test wakeup with multiple tags""" + worker = GpuWorker.__new__(GpuWorker) + worker.model_runner = Mock() + + # Call wakeup with multiple tags + worker.wakeup(tags="weight,kv_cache") + + # Verify model_runner.wakeup was called with correct tags + worker.model_runner.wakeup.assert_called_once_with(tags="weight,kv_cache") + + def test_wakeup_with_kwargs(self): + """Test wakeup passes kwargs to model_runner""" + worker = GpuWorker.__new__(GpuWorker) + worker.model_runner = Mock() + + # Call wakeup with kwargs + worker.wakeup(tags="kv_cache", async_load=True) + + # Verify model_runner.wakeup was called with kwargs + worker.model_runner.wakeup.assert_called_once_with(tags="kv_cache", async_load=True) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/worker/test_worker_process.py b/tests/worker/test_worker_process.py index 19430fafce2..6dd2ad62599 100644 --- a/tests/worker/test_worker_process.py +++ b/tests/worker/test_worker_process.py @@ -13,14 +13,20 @@ # limitations under the License. import logging +import types import unittest +from unittest.mock import AsyncMock, Mock, patch + +from fastdeploy.config import FDConfig +from fastdeploy.engine.request import ControlRequest +from fastdeploy.worker.worker_process import PaddleDisWorkerProc class TestInterceptPaddleLoggers(unittest.TestCase): """Test cases for intercept_paddle_loggers context manager from tools.logger_patch""" def test_intercept_paddle_loggers_with_paddle_prefix(self): - """Test intercept_paddle_loggers configures paddle loggers correctly (line 28-30)""" + """Test intercept_paddle_loggers configures paddle loggers correctly""" from fastdeploy.logger.logger import intercept_paddle_loggers # Create a logger with existing handlers before interception @@ -34,12 +40,12 @@ def test_intercept_paddle_loggers_with_paddle_prefix(self): test_logger.addHandler(handler2) self.assertEqual(len(test_logger.handlers), 2) - # Use the context manager to intercept paddle loggers + # Use context manager to intercept paddle loggers with intercept_paddle_loggers(): # Get logger inside context - should be configured by interceptor intercepted_logger = logging.getLogger(test_logger_name) - # Verify the logger was reconfigured by the interceptor + # Verify the logger was reconfigured by interceptor self.assertEqual(len(intercepted_logger.handlers), 1) self.assertIsInstance(intercepted_logger.handlers[0], logging.StreamHandler) self.assertEqual(intercepted_logger.level, logging.INFO) @@ -49,7 +55,7 @@ def test_intercept_paddle_loggers_with_paddle_prefix(self): test_logger.handlers = [] def test_intercept_paddle_loggers_restores_original(self): - """Test intercept_paddle_loggers restores original getLogger after exit (line 46)""" + """Test intercept_paddle_loggers restores original getLogger after exit""" from fastdeploy.logger.logger import intercept_paddle_loggers # Store original getLogger before context @@ -104,5 +110,178 @@ def test_intercept_paddle_loggers_exception_safety(self): self.assertEqual(logging.getLogger, original_getLogger) +class TestWorkerProcessControlMethod(unittest.TestCase): + """Test cases for PaddleDisWorkerProc control method handling - Coverage for lines 761-786""" + + def setUp(self): + """Set up test fixtures""" + self.mock_fd_config = Mock(spec=FDConfig) + self.mock_fd_config.parallel_config = Mock() + self.mock_fd_config.parallel_config.use_ep = False + self.mock_fd_config.parallel_config.tensor_parallel_size = 1 + self.mock_fd_config.load_config = Mock() + self.mock_fd_config.load_config.dynamic_load_weight = False + + self.process = PaddleDisWorkerProc.__new__(PaddleDisWorkerProc) + self.process.fd_config = self.mock_fd_config + self.process.parallel_config = self.mock_fd_config.parallel_config + self.process.local_rank = 0 + self.process.eplb_config = types.SimpleNamespace(enable_eplb=False) + + # Mock worker - use spec to avoid auto-creating Mock methods + self.process.worker = Mock(spec=[]) # Empty spec = no methods defined + + # Create async mock for queue + self.mock_queue = Mock() + self.mock_queue.put = AsyncMock() + self.process._ctrl_output = self.mock_queue + + def test_run_control_method_unknown_handler(self): + """Test run_control_method with unknown control method""" + # Create a request with unknown method + request = ControlRequest(request_id="test_id", method="unknown_method", args={}) + + self.process.run_control_method(request) + + # Verify put was called with error response + self.mock_queue.put.assert_called_once() + call_args = self.mock_queue.put.call_args[0][0] + self.assertEqual(call_args.request_id, "test_id") + self.assertEqual(call_args.error_code, 400) + + def test_run_control_method_non_callable_handler(self): + """Test run_control_method with non-callable handler""" + # Add a non-callable attribute to worker + self.process.worker.some_method = "not_callable" + + request = ControlRequest(request_id="test_id", method="some_method", args={}) + + self.process.run_control_method(request) + + # Verify put was called with error response + self.mock_queue.put.assert_called_once() + call_args = self.mock_queue.put.call_args[0][0] + self.assertEqual(call_args.error_code, 400) + + def test_run_control_method_success(self): + """Test run_control_method with successful execution""" + # Add a callable method to worker + mock_result = {"result": "success"} + self.process.worker.test_method = Mock(return_value=mock_result) + + request = ControlRequest(request_id="test_id", method="test_method", args={"param": "value"}) + + self.process.run_control_method(request) + + # Verify handler was called with args + self.process.worker.test_method.assert_called_once_with(param="value") + + # Verify put was called with success response + self.mock_queue.put.assert_called_once() + call_args = self.mock_queue.put.call_args[0][0] + self.assertEqual(call_args.request_id, "test_id") + self.assertEqual(call_args.error_code, 200) + + def test_run_control_method_exception(self): + """Test run_control_method with exception in handler""" + + # Add a method that raises exception + def failing_method(**kwargs): + raise ValueError("Test error") + + self.process.worker.test_method = failing_method + + request = ControlRequest(request_id="test_id", method="test_method", args={}) + + with patch("fastdeploy.worker.worker_process.traceback") as mock_traceback: + mock_traceback.format_exc.return_value = "Traceback..." + + self.process.run_control_method(request) + + # Verify put was called with error response + self.mock_queue.put.assert_called_once() + call_args = self.mock_queue.put.call_args[0][0] + self.assertEqual(call_args.request_id, "test_id") + self.assertEqual(call_args.error_code, 500) + + def test_run_control_directly_when_not_use_ep(self): + """Test running control request directly when use_ep is disabled""" + self.process.parallel_config.use_ep = False + + # Add a callable method to worker + self.process.worker.test_method = Mock(return_value={"result": "ok"}) + + control_req = ControlRequest(request_id="test_id", method="test_method", args={}) + + self.process.run_control_method(control_req) + + # Verify handler was called + self.process.worker.test_method.assert_called_once() + + # Verify put was called + self.mock_queue.put.assert_called_once() + + def test_event_loop_caches_ep_control_requests_before_collective_run(self): + self.process.parallel_config.use_ep = True + self.process.parallel_config.ep_group = Mock(world_size=1) + self.process.cached_control_reqs = [] + self.process._run_eplb = Mock() + self.process._tp_barrier_wait = Mock() + self.process.run_control_method = Mock() + self.process.worker_healthy_live_signal = Mock(value=[0]) + self.process.max_chips_per_node = 8 + self.process.nnode = 1 + self.process.ranks = 1 + self.process.task_queue = Mock() + self.process.task_queue.exist_tasks.return_value = False + self.process.task_queue.read_finish_flag = types.SimpleNamespace(get=Mock(return_value=1)) + control_req = ControlRequest(request_id="ep-ctrl", method="pause", args={}) + self.process.task_queue.get_tasks.return_value = ([([control_req], 1)], False) + self.process.exist_task_signal = types.SimpleNamespace(value=[1]) + self.process.worker = types.SimpleNamespace( + preprocess_new_task=Mock(), + model_runner=types.SimpleNamespace(), + execute_model=Mock(), + exist_prefill=Mock(return_value=False), + ) + with ( + patch("fastdeploy.utils.all_gather_values", side_effect=SystemExit), + patch("fastdeploy.worker.worker_process.all_gather_values", side_effect=SystemExit), + ): + with self.assertRaises(SystemExit): + self.process.event_loop_normal() + + self.assertEqual(self.process.cached_control_reqs, [control_req]) + self.process.run_control_method.assert_not_called() + + def test_event_loop_skips_execute_model_when_runner_is_sleeping(self): + self.process.parallel_config.use_ep = False + self.process.parallel_config.tensor_parallel_size = 2 + self.process.fd_config.load_config.dynamic_load_weight = True + self.process.cached_control_reqs = [] + self.process._run_eplb = Mock() + self.process._tp_barrier_wait = Mock(side_effect=SystemExit) + self.process.worker_healthy_live_signal = Mock(value=[0]) + self.process.max_chips_per_node = 8 + self.process.nnode = 1 + self.process.ranks = 1 + self.process.local_rank = 0 + self.process.task_queue = Mock() + self.process.task_queue.exist_tasks.return_value = False + self.process.task_queue.read_finish_flag = types.SimpleNamespace(get=Mock(return_value=0)) + self.process.exist_task_signal = types.SimpleNamespace(value=[0]) + self.process.worker = types.SimpleNamespace( + model_runner=types.SimpleNamespace(is_sleeping=True), + execute_model=Mock(), + exist_prefill=Mock(return_value=False), + ) + + with patch("fastdeploy.worker.worker_process.envs.FD_ENABLE_V1_UPDATE_WEIGHTS", True): + with self.assertRaises(SystemExit): + self.process.event_loop_normal() + + self.process.worker.execute_model.assert_not_called() + + if __name__ == "__main__": unittest.main()