From 97484bc095f97c851a3c2bf774d639e70afc4237 Mon Sep 17 00:00:00 2001 From: Kartik Ohri Date: Tue, 16 Dec 2025 20:20:15 +0530 Subject: [PATCH] Remove decorator override for ratelimit The ratelimit keys in the cache, include a global or scope prefix in the key name. At the decorator level, there is no prefix, and hence it would use the global prefix and end up overriding global ratelimit records for the requests made by the user in the current window. Thinking more of it, I don't think we need decorator level overrides. Scope isolation should be sufficient. --- brainzutils/flask/test/test_ratelimit.py | 114 +++-------------------- brainzutils/ratelimit.py | 53 +++-------- 2 files changed, 24 insertions(+), 143 deletions(-) diff --git a/brainzutils/flask/test/test_ratelimit.py b/brainzutils/flask/test/test_ratelimit.py index 3653407..35fc367 100644 --- a/brainzutils/flask/test/test_ratelimit.py +++ b/brainzutils/flask/test/test_ratelimit.py @@ -112,80 +112,18 @@ def index(): set_rate_limits(self.max_token_requests, self.max_ip_requests, self.ratelimit_window) self.make_requests(client, self.max_token_requests, token="Token %s" % valid_user) - def test_custom_ip_limit(self): - """Test that per_ip_limit parameter overrides global limit.""" - custom_limit = 2 - - @self.app.route("/custom") - @ratelimit(per_ip_limit=custom_limit, window=60) - def custom_endpoint(): - return "OK" - - client = self.app.test_client() - - # First request should succeed - response = client.get("/custom") - self.assertEqual(response.status_code, 200) - self.assertEqual(response.headers["X-RateLimit-Limit"], str(custom_limit)) - self.assertEqual(response.headers["X-RateLimit-Remaining"], "1") - - response = client.get("/custom") - self.assertEqual(response.status_code, 200) - self.assertEqual(response.headers["X-RateLimit-Remaining"], "0") - - response = client.get("/custom") - self.assertEqual(response.status_code, 429) - - def test_custom_window(self): - """Test that window parameter works correctly.""" - @self.app.route("/short-window") - @ratelimit(per_ip_limit=1, window=2) - def short_window_endpoint(): - return "OK" - - client = self.app.test_client() - - response = client.get("/short-window") - self.assertEqual(response.status_code, 200) - response = client.get("/short-window") - self.assertEqual(response.status_code, 429) - - sleep(2.5) - response = client.get("/short-window") - self.assertEqual(response.status_code, 200) - - def test_headers_contain_correct_values(self): - """Test that rate limit headers contain expected values.""" - limit = 5 - window = 30 - - @self.app.route("/headers") - @ratelimit(per_ip_limit=limit, window=window) - def headers_endpoint(): - return "OK" - - client = self.app.test_client() - response = client.get("/headers") - - self.assertEqual(response.status_code, 200) - self.assertIn("X-RateLimit-Limit", response.headers) - self.assertIn("X-RateLimit-Remaining", response.headers) - self.assertIn("X-RateLimit-Reset-In", response.headers) - - self.assertEqual(response.headers["X-RateLimit-Limit"], str(limit)) - self.assertEqual(response.headers["X-RateLimit-Remaining"], str(limit - 1)) - self.assertLessEqual(int(response.headers["X-RateLimit-Reset-In"]), window) - self.assertIn("X-RateLimit-Reset", response.headers) - def test_scope_isolation(self): """Test that different scopes have independent rate limit buckets.""" + set_rate_limits(per_token=100, per_ip=2, window=60, scope="scope_a") + set_rate_limits(per_token=100, per_ip=2, window=60, scope="scope_b") + @self.app.route("/scope-a") - @ratelimit(scope="scope_a", per_ip_limit=2, window=60) + @ratelimit(scope="scope_a") def scope_a_endpoint(): return "A" @self.app.route("/scope-b") - @ratelimit(scope="scope_b", per_ip_limit=2, window=60) + @ratelimit(scope="scope_b") def scope_b_endpoint(): return "B" @@ -245,13 +183,16 @@ def shared_2_endpoint(): def test_no_scope_vs_scoped(self): """Test that unscoped and scoped endpoints have separate buckets.""" + set_rate_limits(per_token=100, per_ip=1, window=60) + set_rate_limits(per_token=100, per_ip=1, window=60, scope="my_scope") + @self.app.route("/unscoped") - @ratelimit(per_ip_limit=1, window=60) + @ratelimit() def unscoped_endpoint(): return "Unscoped" @self.app.route("/scoped") - @ratelimit(scope="my_scope", per_ip_limit=1, window=60) + @ratelimit(scope="my_scope") def scoped_endpoint(): return "Scoped" @@ -290,31 +231,7 @@ def test_set_and_get_scope_limits(self): self.assertIsNotNone(result) self.assertEqual(result["per_token"], None) self.assertEqual(result["per_ip"], None) - self.assertEqual(result["window"], None - ) - - def test_decorator_overrides_cache_scope_limits(self): - """Test that decorator parameters override scope limits from cache.""" - scope = "override_scope" - set_rate_limits(per_token=100, per_ip=10, window=60, scope=scope) - - @self.app.route("/override") - @ratelimit(scope=scope, per_ip_limit=2) - def override_endpoint(): - return "OK" - - client = self.app.test_client() - - response = client.get("/override") - self.assertEqual(response.status_code, 200) - self.assertEqual(response.headers["X-RateLimit-Limit"], "2") - - response = client.get("/override") - self.assertEqual(response.status_code, 200) - - # 3rd request should fail (limit is 2, not 10) - response = client.get("/override") - self.assertEqual(response.status_code, 429) + self.assertEqual(result["window"], None) def test_scope_cache_values_stored_correctly(self): """Test that scope limits are stored in cache with correct keys.""" @@ -359,11 +276,6 @@ def global_endpoint(): def scope_priority_endpoint(): return "OK" - @self.app.route("/decorator-priority") - @ratelimit(scope=scope, per_ip_limit=2) - def decorator_priority_endpoint(): - return "OK" - client = self.app.test_client() response = client.get("/global") @@ -373,7 +285,3 @@ def decorator_priority_endpoint(): response = client.get("/scope-priority") self.assertEqual(response.status_code, 200) self.assertEqual(response.headers["X-RateLimit-Limit"], "3") - - response = client.get("/decorator-priority") - self.assertEqual(response.status_code, 200) - self.assertEqual(response.headers["X-RateLimit-Limit"], "2") diff --git a/brainzutils/ratelimit.py b/brainzutils/ratelimit.py index 9c910ff..bdae84f 100644 --- a/brainzutils/ratelimit.py +++ b/brainzutils/ratelimit.py @@ -78,15 +78,7 @@ def after_request_callbacks(response): def index(): return 'test' - You can also pass custom rate limit parameters directly to the decorator to override - the global/cached values:: - - @app.route('/expensive') - @ratelimit(per_token_limit=10, per_ip_limit=5, window=60) - def expensive_endpoint(): - return 'This endpoint has stricter rate limits' - - Use the scope parameter to isolate rate limits for different endpoints:: + Use the scope parameter to isolate and apply custom rate limits for different endpoints:: @app.route('/api/v1/search') @ratelimit(scope='search') @@ -94,7 +86,7 @@ def search(): return 'Search results' @app.route('/api/v1/upload') - @ratelimit(scope='upload', per_ip_limit=5, window=60) + @ratelimit(scope='upload') def upload(): return 'Upload complete' @@ -114,9 +106,8 @@ def upload(): set_rate_limits(per_token=10, per_ip=5, window=120, scope='upload') Limit resolution order (first non-None value wins): - 1. Decorator parameters (per_token_limit, per_ip_limit, window) - 2. Scope-specific limits from cache (if scope is provided) - 3. Global limits from cache + 1. Scope-specific limits from cache (if scope is provided) + 2. Global limits from cache 4. To enable token based rate limiting, callers need to pass the Authorization header (see above) and the application needs to provide a user validation function:: @@ -240,19 +231,16 @@ def on_over_limit(limit): def _get_rate_limit_helper( limit_type: Literal["per_ip", "per_token"], - _global: dict, _scope: dict = None, _local: dict = None + _global: dict, + _scope: dict = None ) -> dict: values = {} - if _local is not None and (local_limit := _local.get(limit_type)) is not None: - values["limit"] = local_limit - elif _scope is not None and (scope_limit := _scope.get(limit_type)) is not None: + if _scope is not None and (scope_limit := _scope.get(limit_type)) is not None: values["limit"] = scope_limit else: values["limit"] = _global[limit_type] - if _local is not None and (local_window := _local.get("window")) is not None: - values["window"] = local_window - elif _scope is not None and (scope_window := _scope.get("window")) is not None: + if _scope is not None and (scope_window := _scope.get("window")) is not None: values["window"] = scope_window else: values["window"] = _global["window"] @@ -260,7 +248,7 @@ def _get_rate_limit_helper( return values -def get_rate_limit_data(request, per_token_limit=None, per_ip_limit=None, window=None, scope=None): +def get_rate_limit_data(request, scope=None): """Fetch key for the given request. If an Authorization header is provided, the caller will get a better and personalized rate limit. If no header is provided, the caller will be rate limited by IP, which gets an overall lower rate limit. @@ -268,23 +256,14 @@ def get_rate_limit_data(request, per_token_limit=None, per_ip_limit=None, window Args: request: The Flask request object - per_token_limit: Optional override for per-token limit (uses cache value if None) - per_ip_limit: Optional override for per-IP limit (uses cache value if None) - window: Optional override for rate limit window in seconds (uses cache value if None) scope: Optional scope name to check for scope-specific limits in cache Limit resolution order (first non-None value wins): - 1. Decorator parameters (per_token_limit, per_ip_limit, window) - 2. Scope-specific limits from cache (if scope is provided) + 1. Scope-specific limits from cache (if scope is provided) 3. Global limits from cache """ _global = get_rate_limits() _scope = get_rate_limits(scope) if scope else None - _local = { - "per_token": per_token_limit, - "per_ip": per_ip_limit, - "window": window - } # If a user verification function is provided, parse the Authorization header and try to look up that user if ratelimit_user_validation: @@ -294,7 +273,7 @@ def get_rate_limit_data(request, per_token_limit=None, per_ip_limit=None, window is_valid = ratelimit_user_validation(auth_token) if is_valid: values = _get_rate_limit_helper( - "per_token", _global=_global, _scope=_scope, _local=_local + "per_token", _global=_global, _scope=_scope ) values["key"] = auth_token return values @@ -306,21 +285,18 @@ def get_rate_limit_data(request, per_token_limit=None, per_ip_limit=None, window ip = request.remote_addr values = _get_rate_limit_helper( - "per_ip", _global=_global, _scope=_scope, _local=_local + "per_ip", _global=_global, _scope=_scope ) values["key"] = ip return values -def ratelimit(per_token_limit=None, per_ip_limit=None, window=None, scope=None): +def ratelimit(scope=None): """ This is the decorator that should be applied to all view functions that should be rate limited. Args: - per_token_limit: Optional override for per-token limit (uses cache value if None) - per_ip_limit: Optional override for per-IP limit (uses cache value if None) - window: Optional override for rate limit window in seconds (uses cache value if None) scope: Optional scope to isolate rate limits for different endpoints. If provided, the rate limit key will be scoped with this value, allowing different endpoints to have separate rate limit buckets.. @@ -329,9 +305,6 @@ def decorator(f): def rate_limited(*args, **kwargs): data = get_rate_limit_data( request, - per_token_limit=per_token_limit, - per_ip_limit=per_ip_limit, - window=window, scope=scope ) key = f"{scope}:{data['key']}" if scope else data["key"]