From d174d8d8999b24c3716159297ea8ad913c1787d6 Mon Sep 17 00:00:00 2001
From: Giacomo Marciani
Date: Wed, 6 May 2026 18:54:48 -0400
Subject: [PATCH] [clustermgtd] Add retry to compute fleet status retrieval in clustermgtd.

This retry mitigates the impact of networking glitches, such as transient unavailability of IMDS.
---
 CHANGELOG.md                           |  3 +++
 src/slurm_plugin/clustermgtd.py        | 16 ++++++++++------
 tests/slurm_plugin/test_clustermgtd.py |  6 +++++-
 3 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a370dfe3..10fc97e5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,9 @@ This file is used to list changes made in each version of the aws-parallelcluste
 3.16.0
 ------
 
+**ENHANCEMENTS**
+- Mitigate the impact of transient IMDS failures during compute fleet status retrieval.
+
 **BUG FIXES**
 - Fix clustermgtd failing to detect compute node bootstrap timeouts, which prevented the cluster from entering protected mode.
 
diff --git a/src/slurm_plugin/clustermgtd.py b/src/slurm_plugin/clustermgtd.py
index c612dea1..7fcd4ace 100644
--- a/src/slurm_plugin/clustermgtd.py
+++ b/src/slurm_plugin/clustermgtd.py
@@ -97,14 +97,19 @@ class ComputeFleetStatusManager:
     COMPUTE_FLEET_STATUS_ATTRIBUTE = "status"
     COMPUTE_FLEET_LAST_UPDATED_TIME_ATTRIBUTE = "lastStatusUpdatedTime"
 
+    @staticmethod
+    @retry(stop_max_attempt_number=3, wait_fixed=seconds(1))
+    def _get_fleet_status():
+        compute_fleet_raw_data = check_command_output("get-compute-fleet-status.sh")
+        log.debug("Retrieved compute fleet data: %s", compute_fleet_raw_data)
+        return ComputeFleetStatus(
+            json.loads(compute_fleet_raw_data).get(ComputeFleetStatusManager.COMPUTE_FLEET_STATUS_ATTRIBUTE)
+        )
+
     @staticmethod
     def get_status(fallback=None):
         try:
-            compute_fleet_raw_data = check_command_output("get-compute-fleet-status.sh")
-            log.debug("Retrieved compute fleet data: %s", compute_fleet_raw_data)
-            return ComputeFleetStatus(
-                json.loads(compute_fleet_raw_data).get(ComputeFleetStatusManager.COMPUTE_FLEET_STATUS_ATTRIBUTE)
-            )
+            return ComputeFleetStatusManager._get_fleet_status()
         except Exception as e:
             if isinstance(e, CalledProcessError):
                 error = e.stdout.rstrip()
@@ -115,7 +120,6 @@ def get_status(fallback=None):
                 error,
                 fallback,
             )
-
         return fallback
 
     @staticmethod
diff --git a/tests/slurm_plugin/test_clustermgtd.py b/tests/slurm_plugin/test_clustermgtd.py
index 72a9f8b0..95ffb0a5 100644
--- a/tests/slurm_plugin/test_clustermgtd.py
+++ b/tests/slurm_plugin/test_clustermgtd.py
@@ -2304,6 +2304,7 @@ class TestComputeFleetStatusManager:
     )
     def test_get_status(self, mocker, get_item_response, fallback, expected_status):
         check_command_output_mocked = mocker.patch("slurm_plugin.clustermgtd.check_command_output", autospec=True)
+        mocker.patch("retrying.time.sleep")
         compute_fleet_status_manager = ComputeFleetStatusManager()
 
         if get_item_response is Exception:
@@ -2312,7 +2313,10 @@ def test_get_status(self, mocker, get_item_response, fallback, expected_status):
             check_command_output_mocked.return_value = get_item_response
         status = compute_fleet_status_manager.get_status(fallback)
         assert_that(status).is_equal_to(expected_status)
-        check_command_output_mocked.assert_called_once_with("get-compute-fleet-status.sh")
+        if get_item_response is Exception or get_item_response == "":
+            assert_that(check_command_output_mocked.call_count).is_equal_to(3)
+        else:
+            check_command_output_mocked.assert_called_once_with("get-compute-fleet-status.sh")
 
     @pytest.mark.parametrize(
         "desired_status, update_item_response",