@staticmethod
def _sanitize_concurrency(value: Any) -> int:
    """
    Coerce a user-provided concurrency value into a safe integer >= 1.

    A value is accepted only when it converts to ``int`` without losing
    information: real integers, whole-number floats such as ``4.0`` and
    numeric text such as ``"4"``. Everything else (``None``, booleans,
    fractional numbers, NaN/infinity, unconvertible objects, and results
    below 1) is reported through ``log.warn`` and replaced with 1.

    Args:
        value: Whatever the user-supplied concurrency modifier returned.

    Returns:
        The validated concurrency, always an ``int`` >= 1.
    """
    sanitized = None

    # bool is a subclass of int, so True/False must be filtered out before
    # the generic conversion would quietly turn them into 1/0.
    if value is not None and not isinstance(value, bool):
        try:
            converted = int(value)
            # Text has no numeric identity to compare against. Every other
            # type must round-trip losslessly, so 2.7, Decimal("2.7") or
            # Fraction(5, 2) are rejected instead of being truncated to 2.
            if isinstance(value, (str, bytes, bytearray)) or converted == value:
                sanitized = converted
        except Exception:
            # The value originates from user code: any failure while
            # converting or comparing must fall back to the default
            # rather than crash the scaler.
            sanitized = None

    if sanitized is None or sanitized < 1:
        log.warn(
            f"JobScaler.set_scale | Invalid concurrency value: {value!r}. Defaulting to 1."
        )
        return 1

    return sanitized