Skip to content
10 changes: 8 additions & 2 deletions openwisp_controller/connection/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import time

import swapper
from celery import current_app, shared_task
from celery import current_app, current_task, shared_task
from celery.exceptions import SoftTimeLimitExceeded
from django.core.exceptions import ObjectDoesNotExist
from django.utils.translation import gettext_lazy as _
Expand All @@ -20,11 +20,17 @@ def _is_update_in_progress(device_id):
active = current_app.control.inspect().active()
if not active:
return False
current_task_id = getattr(current_task, 'request', None)
if current_task_id:
current_task_id = current_task_id.id
else:
current_task_id = None
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
else:
current_task_id = None

Not needed

# check if there's any other running task before adding it
for task_list in active.values():
for task in task_list:
if task["name"] == _TASK_NAME and str(device_id) in task["args"]:
return True
if task.get("id") != current_task_id:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

avoid nesting this and add it to the previous if, we should probably rewrite this using all() https://docs.python.org/3/library/functions.html#all

return True
return False


Expand Down
78 changes: 78 additions & 0 deletions openwisp_controller/connection/tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,84 @@ def test_launch_command_exception(self, *args):
self.assertEqual(command.output, "Internal system error: test error\n")


class TestIsUpdateInProgress(CreateConnectionsMixin, TestCase):
@mock.patch("openwisp_controller.connection.tasks.current_task")
@mock.patch("openwisp_controller.connection.tasks.current_app")
def test_is_update_in_progress_same_worker(
self, mocked_current_app, mocked_current_task
):
device_id = 1
mocked_current_task.request.id = "task123"
mocked_inspect = mock.Mock()
mocked_current_app.control.inspect.return_value = mocked_inspect
mocked_inspect.active.return_value = {
"worker1": [
{
"name": "openwisp_controller.connection.tasks.update_config",
"args": ["1"],
"id": "task123",
}
]
}
result = tasks._is_update_in_progress(device_id)
self.assertFalse(result)

@mock.patch("openwisp_controller.connection.tasks.current_task")
@mock.patch("openwisp_controller.connection.tasks.current_app")
def test_is_update_in_progress_different_worker(
self, mocked_current_app, mocked_current_task
):
device_id = 1
mocked_current_task.request.id = "task123"
mocked_inspect = mock.Mock()
mocked_current_app.control.inspect.return_value = mocked_inspect
mocked_inspect.active.return_value = {
"worker2": [
{
"name": "openwisp_controller.connection.tasks.update_config",
"args": ["1"],
"id": "task456",
}
]
}
result = tasks._is_update_in_progress(device_id)
self.assertTrue(result)

@mock.patch("openwisp_controller.connection.tasks.current_task")
@mock.patch("openwisp_controller.connection.tasks.current_app")
def test_is_update_in_progress_no_active_tasks(
self, mocked_current_app, mocked_current_task
):
device_id = 1
mocked_current_task.request.id = "task123"
mocked_inspect = mock.Mock()
mocked_current_app.control.inspect.return_value = mocked_inspect
mocked_inspect.active.return_value = {}
result = tasks._is_update_in_progress(device_id)
self.assertFalse(result)

@mock.patch("openwisp_controller.connection.tasks.current_task")
@mock.patch("openwisp_controller.connection.tasks.current_app")
def test_is_update_in_progress_different_device(
self, mocked_current_app, mocked_current_task
):
device_id = 1
mocked_current_task.request.id = "task123"
mocked_inspect = mock.Mock()
mocked_current_app.control.inspect.return_value = mocked_inspect
mocked_inspect.active.return_value = {
"worker1": [
{
"name": "openwisp_controller.connection.tasks.update_config",
"args": ["2"],
"id": "task456",
}
]
}
result = tasks._is_update_in_progress(device_id)
self.assertFalse(result)


class TestTransactionTasks(
TestRegistrationMixin, CreateConnectionsMixin, TransactionTestCase
):
Expand Down
Loading