diff --git a/openwisp_controller/connection/tasks.py b/openwisp_controller/connection/tasks.py index d75bbde20..886bfd9a4 100644 --- a/openwisp_controller/connection/tasks.py +++ b/openwisp_controller/connection/tasks.py @@ -16,20 +16,25 @@ _TASK_NAME = "openwisp_controller.connection.tasks.update_config" -def _is_update_in_progress(device_id): +def _is_update_in_progress(device_id, current_task_id=None): active = current_app.control.inspect().active() if not active: return False # check if there's any other running task before adding it for task_list in active.values(): for task in task_list: - if task["name"] == _TASK_NAME and str(device_id) in task["args"]: + # skip the current task itself + if current_task_id and task.get("id") == current_task_id: + continue + if task.get("name") == _TASK_NAME and str(device_id) in task.get( + "args", "" + ): return True return False -@shared_task -def update_config(device_id): +@shared_task(bind=True) +def update_config(self, device_id): """ Launches the ``update_config()`` operation of a specific device in the background @@ -48,7 +53,7 @@ def update_config(device_id): except ObjectDoesNotExist as e: logger.warning(f'update_config("{device_id}") failed: {e}') return - if _is_update_in_progress(device_id): + if _is_update_in_progress(device_id, current_task_id=self.request.id): return try: device_conn = DeviceConnection.get_working_connection(device) diff --git a/openwisp_controller/connection/tests/test_tasks.py b/openwisp_controller/connection/tests/test_tasks.py index 700dbbf32..0bda66c4b 100644 --- a/openwisp_controller/connection/tests/test_tasks.py +++ b/openwisp_controller/connection/tests/test_tasks.py @@ -9,6 +9,7 @@ from ...config.tests.test_controller import TestRegistrationMixin from .. import tasks +from ..tasks import _TASK_NAME, _is_update_in_progress from .utils import CreateConnectionsMixin Command = load_model("connection", "Command") @@ -89,6 +90,98 @@ def test_launch_command_exception(self, *args): self.assertEqual(command.output, "Internal system error: test error\n") +class TestIsUpdateInProgress(CreateConnectionsMixin, TestCase): + + def _get_mocked_active_tasks(self, device_id, task_id="task-123"): + return { + "celery@worker1": [ + { + "id": task_id, + "name": _TASK_NAME, + "args": f"('{device_id}',)", + } + ] + } + + @mock.patch("openwisp_controller.connection.tasks.current_app") + def test_is_update_in_progress_without_current_task_id(self, mock_app): + + device_id = uuid.uuid4() + current_task_id = "task-123" + + mock_app.control.inspect.return_value.active.return_value = ( + self._get_mocked_active_tasks(device_id, task_id=current_task_id) + ) + + # BUG: Without passing current_task_id, the function returns True + # even though the only active task IS the current task + result = _is_update_in_progress(device_id) + self.assertTrue(result) + + @mock.patch("openwisp_controller.connection.tasks.current_app") + def test_is_update_in_progress_with_current_task_id_excluded(self, mock_app): + + device_id = uuid.uuid4() + current_task_id = "task-123" + + mock_app.control.inspect.return_value.active.return_value = ( + self._get_mocked_active_tasks(device_id, task_id=current_task_id) + ) + + # FIX: With current_task_id provided, the function correctly returns False + result = _is_update_in_progress(device_id, current_task_id=current_task_id) + self.assertFalse(result) + + @mock.patch("openwisp_controller.connection.tasks.current_app") + def test_is_update_in_progress_detects_another_task(self, mock_app): + + device_id = uuid.uuid4() + current_task_id = "task-123" + another_task_id = "task-456" + + # Mock active tasks with both current task and another task + mock_app.control.inspect.return_value.active.return_value = { + "celery@worker1": [ + { + "id": current_task_id, + "name": _TASK_NAME, + "args": f"('{device_id}',)", + }, + { + "id": another_task_id, + "name": _TASK_NAME, + "args": f"('{device_id}',)", + }, + ] + } + + # Should return True because another task IS running + result = _is_update_in_progress(device_id, current_task_id=current_task_id) + self.assertTrue(result) + + @mock.patch("openwisp_controller.connection.tasks.current_app") + def test_is_update_in_progress_no_active_tasks(self, mock_app): + + device_id = uuid.uuid4() + mock_app.control.inspect.return_value.active.return_value = None + + result = _is_update_in_progress(device_id, current_task_id="task-123") + self.assertFalse(result) + + @mock.patch("openwisp_controller.connection.tasks.current_app") + def test_is_update_in_progress_different_device(self, mock_app): + + device_id = uuid.uuid4() + other_device_id = uuid.uuid4() + + mock_app.control.inspect.return_value.active.return_value = ( + self._get_mocked_active_tasks(other_device_id, task_id="task-456") + ) + + result = _is_update_in_progress(device_id, current_task_id="task-123") + self.assertFalse(result) + + class TestTransactionTasks( TestRegistrationMixin, CreateConnectionsMixin, TransactionTestCase ):