diff --git a/brainzutils/flask/test/test_ratelimit.py b/brainzutils/flask/test/test_ratelimit.py index 3653407..35fc367 100644 --- a/brainzutils/flask/test/test_ratelimit.py +++ b/brainzutils/flask/test/test_ratelimit.py @@ -112,80 +112,18 @@ def index(): set_rate_limits(self.max_token_requests, self.max_ip_requests, self.ratelimit_window) self.make_requests(client, self.max_token_requests, token="Token %s" % valid_user) - def test_custom_ip_limit(self): - """Test that per_ip_limit parameter overrides global limit.""" - custom_limit = 2 - - @self.app.route("/custom") - @ratelimit(per_ip_limit=custom_limit, window=60) - def custom_endpoint(): - return "OK" - - client = self.app.test_client() - - # First request should succeed - response = client.get("/custom") - self.assertEqual(response.status_code, 200) - self.assertEqual(response.headers["X-RateLimit-Limit"], str(custom_limit)) - self.assertEqual(response.headers["X-RateLimit-Remaining"], "1") - - response = client.get("/custom") - self.assertEqual(response.status_code, 200) - self.assertEqual(response.headers["X-RateLimit-Remaining"], "0") - - response = client.get("/custom") - self.assertEqual(response.status_code, 429) - - def test_custom_window(self): - """Test that window parameter works correctly.""" - @self.app.route("/short-window") - @ratelimit(per_ip_limit=1, window=2) - def short_window_endpoint(): - return "OK" - - client = self.app.test_client() - - response = client.get("/short-window") - self.assertEqual(response.status_code, 200) - response = client.get("/short-window") - self.assertEqual(response.status_code, 429) - - sleep(2.5) - response = client.get("/short-window") - self.assertEqual(response.status_code, 200) - - def test_headers_contain_correct_values(self): - """Test that rate limit headers contain expected values.""" - limit = 5 - window = 30 - - @self.app.route("/headers") - @ratelimit(per_ip_limit=limit, window=window) - def headers_endpoint(): - return "OK" - - client = self.app.test_client() - response = client.get("/headers") - - self.assertEqual(response.status_code, 200) - self.assertIn("X-RateLimit-Limit", response.headers) - self.assertIn("X-RateLimit-Remaining", response.headers) - self.assertIn("X-RateLimit-Reset-In", response.headers) - - self.assertEqual(response.headers["X-RateLimit-Limit"], str(limit)) - self.assertEqual(response.headers["X-RateLimit-Remaining"], str(limit - 1)) - self.assertLessEqual(int(response.headers["X-RateLimit-Reset-In"]), window) - self.assertIn("X-RateLimit-Reset", response.headers) - def test_scope_isolation(self): """Test that different scopes have independent rate limit buckets.""" + set_rate_limits(per_token=100, per_ip=2, window=60, scope="scope_a") + set_rate_limits(per_token=100, per_ip=2, window=60, scope="scope_b") + @self.app.route("/scope-a") - @ratelimit(scope="scope_a", per_ip_limit=2, window=60) + @ratelimit(scope="scope_a") def scope_a_endpoint(): return "A" @self.app.route("/scope-b") - @ratelimit(scope="scope_b", per_ip_limit=2, window=60) + @ratelimit(scope="scope_b") def scope_b_endpoint(): return "B" @@ -245,13 +183,16 @@ def shared_2_endpoint(): def test_no_scope_vs_scoped(self): """Test that unscoped and scoped endpoints have separate buckets.""" + set_rate_limits(per_token=100, per_ip=1, window=60) + set_rate_limits(per_token=100, per_ip=1, window=60, scope="my_scope") + @self.app.route("/unscoped") - @ratelimit(per_ip_limit=1, window=60) + @ratelimit() def unscoped_endpoint(): return "Unscoped" @self.app.route("/scoped") - @ratelimit(scope="my_scope", per_ip_limit=1, window=60) + @ratelimit(scope="my_scope") def scoped_endpoint(): return "Scoped" @@ -290,31 +231,7 @@ def test_set_and_get_scope_limits(self): self.assertIsNotNone(result) self.assertEqual(result["per_token"], None) self.assertEqual(result["per_ip"], None) - self.assertEqual(result["window"], None - ) - - def test_decorator_overrides_cache_scope_limits(self): - """Test that decorator parameters override scope limits from cache.""" - scope = "override_scope" - set_rate_limits(per_token=100, per_ip=10, window=60, scope=scope) - - @self.app.route("/override") - @ratelimit(scope=scope, per_ip_limit=2) - def override_endpoint(): - return "OK" - - client = self.app.test_client() - - response = client.get("/override") - self.assertEqual(response.status_code, 200) - self.assertEqual(response.headers["X-RateLimit-Limit"], "2") - - response = client.get("/override") - self.assertEqual(response.status_code, 200) - - # 3rd request should fail (limit is 2, not 10) - response = client.get("/override") - self.assertEqual(response.status_code, 429) + self.assertEqual(result["window"], None) def test_scope_cache_values_stored_correctly(self): """Test that scope limits are stored in cache with correct keys.""" @@ -359,11 +276,6 @@ def global_endpoint(): def scope_priority_endpoint(): return "OK" - @self.app.route("/decorator-priority") - @ratelimit(scope=scope, per_ip_limit=2) - def decorator_priority_endpoint(): - return "OK" - client = self.app.test_client() response = client.get("/global") @@ -373,7 +285,3 @@ def decorator_priority_endpoint(): response = client.get("/scope-priority") self.assertEqual(response.status_code, 200) self.assertEqual(response.headers["X-RateLimit-Limit"], "3") - - response = client.get("/decorator-priority") - self.assertEqual(response.status_code, 200) - self.assertEqual(response.headers["X-RateLimit-Limit"], "2") diff --git a/brainzutils/ratelimit.py b/brainzutils/ratelimit.py index 9c910ff..bdae84f 100644 --- a/brainzutils/ratelimit.py +++ b/brainzutils/ratelimit.py @@ -78,15 +78,7 @@ def after_request_callbacks(response): def index(): return 'test' - You can also pass custom rate limit parameters directly to the decorator to override - the global/cached values:: - - @app.route('/expensive') - @ratelimit(per_token_limit=10, per_ip_limit=5, window=60) - def expensive_endpoint(): - return 'This endpoint has stricter rate limits' - - Use the scope parameter to isolate rate limits for different endpoints:: + Use the scope parameter to isolate and apply custom rate limits for different endpoints:: @app.route('/api/v1/search') @ratelimit(scope='search') @@ -94,7 +86,7 @@ def search(): return 'Search results' @app.route('/api/v1/upload') - @ratelimit(scope='upload', per_ip_limit=5, window=60) + @ratelimit(scope='upload') def upload(): return 'Upload complete' @@ -114,9 +106,8 @@ def upload(): set_rate_limits(per_token=10, per_ip=5, window=120, scope='upload') Limit resolution order (first non-None value wins): - 1. Decorator parameters (per_token_limit, per_ip_limit, window) - 2. Scope-specific limits from cache (if scope is provided) - 3. Global limits from cache + 1. Scope-specific limits from cache (if scope is provided) + 2. Global limits from cache 4. To enable token based rate limiting, callers need to pass the Authorization header (see above) and the application needs to provide a user validation function:: @@ -240,19 +231,16 @@ def on_over_limit(limit): def _get_rate_limit_helper( limit_type: Literal["per_ip", "per_token"], - _global: dict, _scope: dict = None, _local: dict = None + _global: dict, + _scope: dict = None ) -> dict: values = {} - if _local is not None and (local_limit := _local.get(limit_type)) is not None: - values["limit"] = local_limit - elif _scope is not None and (scope_limit := _scope.get(limit_type)) is not None: + if _scope is not None and (scope_limit := _scope.get(limit_type)) is not None: values["limit"] = scope_limit else: values["limit"] = _global[limit_type] - if _local is not None and (local_window := _local.get("window")) is not None: - values["window"] = local_window - elif _scope is not None and (scope_window := _scope.get("window")) is not None: + if _scope is not None and (scope_window := _scope.get("window")) is not None: values["window"] = scope_window else: values["window"] = _global["window"] @@ -260,7 +248,7 @@ def _get_rate_limit_helper( return values -def get_rate_limit_data(request, per_token_limit=None, per_ip_limit=None, window=None, scope=None): +def get_rate_limit_data(request, scope=None): """Fetch key for the given request. If an Authorization header is provided, the caller will get a better and personalized rate limit. If no header is provided, the caller will be rate limited by IP, which gets an overall lower rate limit. @@ -268,23 +256,14 @@ def get_rate_limit_data(request, per_token_limit=None, per_ip_limit=None, window Args: request: The Flask request object - per_token_limit: Optional override for per-token limit (uses cache value if None) - per_ip_limit: Optional override for per-IP limit (uses cache value if None) - window: Optional override for rate limit window in seconds (uses cache value if None) scope: Optional scope name to check for scope-specific limits in cache Limit resolution order (first non-None value wins): - 1. Decorator parameters (per_token_limit, per_ip_limit, window) - 2. Scope-specific limits from cache (if scope is provided) + 1. Scope-specific limits from cache (if scope is provided) 3. Global limits from cache """ _global = get_rate_limits() _scope = get_rate_limits(scope) if scope else None - _local = { - "per_token": per_token_limit, - "per_ip": per_ip_limit, - "window": window - } # If a user verification function is provided, parse the Authorization header and try to look up that user if ratelimit_user_validation: @@ -294,7 +273,7 @@ def get_rate_limit_data(request, per_token_limit=None, per_ip_limit=None, window is_valid = ratelimit_user_validation(auth_token) if is_valid: values = _get_rate_limit_helper( - "per_token", _global=_global, _scope=_scope, _local=_local + "per_token", _global=_global, _scope=_scope ) values["key"] = auth_token return values @@ -306,21 +285,18 @@ def get_rate_limit_data(request, per_token_limit=None, per_ip_limit=None, window ip = request.remote_addr values = _get_rate_limit_helper( - "per_ip", _global=_global, _scope=_scope, _local=_local + "per_ip", _global=_global, _scope=_scope ) values["key"] = ip return values -def ratelimit(per_token_limit=None, per_ip_limit=None, window=None, scope=None): +def ratelimit(scope=None): """ This is the decorator that should be applied to all view functions that should be rate limited. Args: - per_token_limit: Optional override for per-token limit (uses cache value if None) - per_ip_limit: Optional override for per-IP limit (uses cache value if None) - window: Optional override for rate limit window in seconds (uses cache value if None) scope: Optional scope to isolate rate limits for different endpoints. If provided, the rate limit key will be scoped with this value, allowing different endpoints to have separate rate limit buckets.. @@ -329,9 +305,6 @@ def decorator(f): def rate_limited(*args, **kwargs): data = get_rate_limit_data( request, - per_token_limit=per_token_limit, - per_ip_limit=per_ip_limit, - window=window, scope=scope ) key = f"{scope}:{data['key']}" if scope else data["key"]