From 7ade1a425ad1439db0d45c6565c64f746db2ece8 Mon Sep 17 00:00:00 2001 From: Kartik Ohri Date: Thu, 11 Dec 2025 01:01:23 +0530 Subject: [PATCH] Add scoped and per-endpoint ratelimit customizations - Add optional parameters to ratelimit decorator: per_token_limit, per_ip_limit, window, and scope - Add scope parameter to customize rate limits for a bundle of endpoints - Implement limit resolution order: decorator params > scope cache > global --- brainzutils/flask/test/test_ratelimit.py | 404 +++++++++++++++++++---- brainzutils/ratelimit.py | 272 +++++++++------ 2 files changed, 510 insertions(+), 166 deletions(-) diff --git a/brainzutils/flask/test/test_ratelimit.py b/brainzutils/flask/test/test_ratelimit.py index ba94e67..3653407 100644 --- a/brainzutils/flask/test/test_ratelimit.py +++ b/brainzutils/flask/test/test_ratelimit.py @@ -3,7 +3,15 @@ from time import sleep from brainzutils import flask, cache -from brainzutils.ratelimit import ratelimit, set_rate_limits, inject_x_rate_headers, set_user_validation_function +from brainzutils.ratelimit import ( + ratelimit, + set_rate_limits, + get_rate_limits, + inject_x_rate_headers, + set_user_validation_function, + ratelimit_cache_namespace, +) + valid_user = "41FB6EEB-636B-4F7C-B376-3A8613F1E69A" def validate_user(user): @@ -11,6 +19,7 @@ def validate_user(user): return True return False + class RatelimitTestCase(unittest.TestCase): host = os.environ.get("REDIS_HOST", "localhost") @@ -29,87 +38,342 @@ def setUp(self): # Making sure there are no items in cache before we run each test cache.flush_all() - def test_create_app(self): - app = flask.CustomFlask(__name__) - self.assertIsNotNone(app) - - def test_ratelimit(self): - """ Tests that the ratelimit decorator works - """ - - # Set the limits as per defines in this class set_rate_limits(self.max_token_requests, self.max_ip_requests, self.ratelimit_window) + self.app = flask.CustomFlask(__name__) + self.app.debug = True + self.app.config["SECRET_KEY"] = "this is a totally secret key btw" - # create an app - app = flask.CustomFlask(__name__) - self.assertIsNotNone(app) - app.debug = True - app.config['SECRET_KEY'] = 'this is a totally secret key btw' - app.init_debug_toolbar() - - @app.after_request + @self.app.after_request def after_request_callbacks(response): return inject_x_rate_headers(response) - # add a dummy route - @app.route('/') + def tearDown(self): + self.app = None + + def print_headers(self, response): + print("X-RateLimit-Remaining", response.headers["X-RateLimit-Remaining"]) + print("X-RateLimit-Limit", response.headers["X-RateLimit-Limit"]) + print("X-RateLimit-Reset", response.headers["X-RateLimit-Reset"]) + print("X-RateLimit-Reset-In", response.headers["X-RateLimit-Reset-In"]) + print() + + def make_requests(self, client, nominal_num_requests, token = None): + print("===== make %d requests" % nominal_num_requests) + # make one more than the allowed number of requests to catch the 429 + num_requests = nominal_num_requests + 1 + + # make a specified number of requests + while True: + reset_time = 0 + restart = False + for i in range(num_requests): + if token: + response = client.get("/", headers={"Authorization": token}) + else: + response = client.get("/") + if reset_time == 0: + reset_time = response.headers["X-RateLimit-Reset"] + + if reset_time != response.headers["X-RateLimit-Reset"]: + # Whoops, we didn"t get our tests done before the window expired. start over. 
+ restart = True + + # when restarting we need to do one request less, since the current requests counts to the new window + num_requests = nominal_num_requests + break + + if i == num_requests - 1: + self.assertEqual(response.status_code, 429) + else: + self.assertEqual(response.status_code, 200) + self.assertEqual(int(response.headers["X-RateLimit-Remaining"]), num_requests - i - 2) + self.print_headers(response) + + sleep(1.1) + + if not restart: + break + + def test_ratelimit(self): + """ Tests that the ratelimit decorator works """ + @self.app.route("/") @ratelimit() def index(): - return 'test' - - def print_headers(response): - print("X-RateLimit-Remaining", response.headers['X-RateLimit-Remaining']) - print("X-RateLimit-Limit", response.headers['X-RateLimit-Limit']) - print("X-RateLimit-Reset", response.headers['X-RateLimit-Reset']) - print("X-RateLimit-Reset-In", response.headers['X-RateLimit-Reset-In']) - print() - - - def make_requests(client, nominal_num_requests, token = None): - - print("===== make %d requests" % nominal_num_requests) - # make one more than the allowed number of requests to catch the 429 - num_requests = nominal_num_requests + 1 - - # make a specified number of requests - while True: - reset_time = 0 - restart = False - for i in range(num_requests): - if token: - response = client.get('/', headers={'Authorization': token}) - else: - response = client.get('/') - if reset_time == 0: - reset_time = response.headers['X-RateLimit-Reset'] - - if reset_time != response.headers['X-RateLimit-Reset']: - # Whoops, we didn't get our tests done before the window expired. start over. - restart = True - - # when restarting we need to do one request less, since the current requests counts to the new window - num_requests = nominal_num_requests - break - - if i == num_requests - 1: - self.assertEqual(response.status_code, 429) - else: - self.assertEqual(response.status_code, 200) - self.assertEqual(int(response.headers['X-RateLimit-Remaining']), num_requests - i - 2) - print_headers(response) - - sleep(1.1) - - if not restart: - break + return "test" - client = app.test_client() + client = self.app.test_client() # Make a pile of requests based on IP address - make_requests(client, self.max_ip_requests) + self.make_requests(client, self.max_ip_requests) # Set a user token and make requests based on the token cache.flush_all() set_user_validation_function(validate_user) set_rate_limits(self.max_token_requests, self.max_ip_requests, self.ratelimit_window) - make_requests(client, self.max_token_requests, token="Token %s" % valid_user) + self.make_requests(client, self.max_token_requests, token="Token %s" % valid_user) + + def test_custom_ip_limit(self): + """Test that per_ip_limit parameter overrides global limit.""" + custom_limit = 2 + + @self.app.route("/custom") + @ratelimit(per_ip_limit=custom_limit, window=60) + def custom_endpoint(): + return "OK" + + client = self.app.test_client() + + # First request should succeed + response = client.get("/custom") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.headers["X-RateLimit-Limit"], str(custom_limit)) + self.assertEqual(response.headers["X-RateLimit-Remaining"], "1") + + response = client.get("/custom") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.headers["X-RateLimit-Remaining"], "0") + + response = client.get("/custom") + self.assertEqual(response.status_code, 429) + + def test_custom_window(self): + """Test that window parameter works correctly.""" + 
@self.app.route("/short-window") + @ratelimit(per_ip_limit=1, window=2) + def short_window_endpoint(): + return "OK" + + client = self.app.test_client() + + response = client.get("/short-window") + self.assertEqual(response.status_code, 200) + response = client.get("/short-window") + self.assertEqual(response.status_code, 429) + + sleep(2.5) + response = client.get("/short-window") + self.assertEqual(response.status_code, 200) + + def test_headers_contain_correct_values(self): + """Test that rate limit headers contain expected values.""" + limit = 5 + window = 30 + + @self.app.route("/headers") + @ratelimit(per_ip_limit=limit, window=window) + def headers_endpoint(): + return "OK" + + client = self.app.test_client() + response = client.get("/headers") + + self.assertEqual(response.status_code, 200) + self.assertIn("X-RateLimit-Limit", response.headers) + self.assertIn("X-RateLimit-Remaining", response.headers) + self.assertIn("X-RateLimit-Reset-In", response.headers) + + self.assertEqual(response.headers["X-RateLimit-Limit"], str(limit)) + self.assertEqual(response.headers["X-RateLimit-Remaining"], str(limit - 1)) + self.assertLessEqual(int(response.headers["X-RateLimit-Reset-In"]), window) + self.assertIn("X-RateLimit-Reset", response.headers) + + def test_scope_isolation(self): + """Test that different scopes have independent rate limit buckets.""" + @self.app.route("/scope-a") + @ratelimit(scope="scope_a", per_ip_limit=2, window=60) + def scope_a_endpoint(): + return "A" + + @self.app.route("/scope-b") + @ratelimit(scope="scope_b", per_ip_limit=2, window=60) + def scope_b_endpoint(): + return "B" + + client = self.app.test_client() + + # Exhaust scope_a limit + response = client.get("/scope-a") + self.assertEqual(response.status_code, 200) + response = client.get("/scope-a") + self.assertEqual(response.status_code, 200) + response = client.get("/scope-a") + self.assertEqual(response.status_code, 429) + + # scope_b should still work independently + response = client.get("/scope-b") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.headers["X-RateLimit-Remaining"], "1") + + response = client.get("/scope-b") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.headers["X-RateLimit-Remaining"], "0") + response = client.get("/scope-b") + self.assertEqual(response.status_code, 429) + + def test_same_scope_shared_limit(self): + """Test that endpoints with the same scope share rate limit bucket.""" + scope = "shared" + set_rate_limits(per_token=100, per_ip=2, window=60, scope=scope) + + @self.app.route("/shared-1") + @ratelimit(scope=scope) + def shared_1_endpoint(): + return "Shared 1" + + @self.app.route("/shared-2") + @ratelimit(scope=scope) + def shared_2_endpoint(): + return "Shared 2" + + client = self.app.test_client() + + # Make request to shared-1 + response = client.get("/shared-1") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.headers["X-RateLimit-Remaining"], "1") + + # Make request to shared-2 (should share the count) + response = client.get("/shared-2") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.headers["X-RateLimit-Remaining"], "0") + + # Both should now be rate limited + response = client.get("/shared-1") + self.assertEqual(response.status_code, 429) + response = client.get("/shared-2") + self.assertEqual(response.status_code, 429) + + def test_no_scope_vs_scoped(self): + """Test that unscoped and scoped endpoints have separate buckets.""" + @self.app.route("/unscoped") + 
@ratelimit(per_ip_limit=1, window=60) + def unscoped_endpoint(): + return "Unscoped" + + @self.app.route("/scoped") + @ratelimit(scope="my_scope", per_ip_limit=1, window=60) + def scoped_endpoint(): + return "Scoped" + + client = self.app.test_client() + + # Exhaust unscoped limit + response = client.get("/unscoped") + self.assertEqual(response.status_code, 200) + response = client.get("/unscoped") + self.assertEqual(response.status_code, 429) + + # Scoped endpoint should still work + response = client.get("/scoped") + self.assertEqual(response.status_code, 200) + + # Now exhaust scoped + response = client.get("/scoped") + self.assertEqual(response.status_code, 429) + + def test_set_and_get_scope_limits(self): + """Test that set_rate_limits_for_scope and get_rate_limits work.""" + scope = "test_scope" + per_token = 100 + per_ip = 50 + window = 120 + + set_rate_limits(per_token, per_ip, window, scope=scope) + + result = get_rate_limits(scope) + self.assertIsNotNone(result) + self.assertEqual(result["per_token"], per_token) + self.assertEqual(result["per_ip"], per_ip) + self.assertEqual(result["window"], window) + + result = get_rate_limits("nonexistent_scope") + self.assertIsNotNone(result) + self.assertEqual(result["per_token"], None) + self.assertEqual(result["per_ip"], None) + self.assertEqual(result["window"], None + ) + + def test_decorator_overrides_cache_scope_limits(self): + """Test that decorator parameters override scope limits from cache.""" + scope = "override_scope" + set_rate_limits(per_token=100, per_ip=10, window=60, scope=scope) + + @self.app.route("/override") + @ratelimit(scope=scope, per_ip_limit=2) + def override_endpoint(): + return "OK" + + client = self.app.test_client() + + response = client.get("/override") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.headers["X-RateLimit-Limit"], "2") + + response = client.get("/override") + self.assertEqual(response.status_code, 200) + + # 3rd request should fail (limit is 2, not 10) + response = client.get("/override") + self.assertEqual(response.status_code, 429) + + def test_scope_cache_values_stored_correctly(self): + """Test that scope limits are stored in cache with correct keys.""" + scope = "verify_cache" + per_token = 50 + per_ip = 25 + window = 30 + + set_rate_limits(per_token, per_ip, window, scope=scope) + + # Verify the values are in cache with correct keys + stored_per_token = cache.get( + f"{scope}:rate_limit_per_token_limit", + namespace=ratelimit_cache_namespace + ) + stored_per_ip = cache.get( + f"{scope}:rate_limit_per_ip_limit", + namespace=ratelimit_cache_namespace + ) + stored_window = cache.get( + f"{scope}:rate_limit_window", + namespace=ratelimit_cache_namespace + ) + + self.assertEqual(int(stored_per_token), per_token) + self.assertEqual(int(stored_per_ip), per_ip) + self.assertEqual(int(stored_window), window) + + def test_scope_limits_override_global(self): + """Test that scope limits override global limits.""" + set_rate_limits(per_token=100, per_ip=10, window=60) + scope = "priority_scope" + set_rate_limits(per_token=100, per_ip=3, window=60, scope=scope) + + @self.app.route("/global") + @ratelimit() + def global_endpoint(): + return "OK" + + @self.app.route("/scope-priority") + @ratelimit(scope=scope) + def scope_priority_endpoint(): + return "OK" + + @self.app.route("/decorator-priority") + @ratelimit(scope=scope, per_ip_limit=2) + def decorator_priority_endpoint(): + return "OK" + + client = self.app.test_client() + + response = client.get("/global") + 
self.assertEqual(response.status_code, 200) + self.assertEqual(response.headers["X-RateLimit-Limit"], "10") + + response = client.get("/scope-priority") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.headers["X-RateLimit-Limit"], "3") + + response = client.get("/decorator-priority") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.headers["X-RateLimit-Limit"], "2") diff --git a/brainzutils/ratelimit.py b/brainzutils/ratelimit.py index 280e4dd..9c910ff 100644 --- a/brainzutils/ratelimit.py +++ b/brainzutils/ratelimit.py @@ -6,6 +6,7 @@ # import time from functools import update_wrapper +from typing import Literal from flask import request, g from werkzeug.exceptions import TooManyRequests @@ -17,6 +18,11 @@ ratelimit_timeout = "rate_limits_timeout" # Defaults +ratelimit_defaults = { + "per_token": 50, + "per_ip": 30, + "window": 10 +} ratelimit_per_token_default = 50 ratelimit_per_ip_default = 30 ratelimit_window_default = 10 @@ -72,6 +78,26 @@ def after_request_callbacks(response): def index(): return 'test' + You can also pass custom rate limit parameters directly to the decorator to override + the global/cached values:: + + @app.route('/expensive') + @ratelimit(per_token_limit=10, per_ip_limit=5, window=60) + def expensive_endpoint(): + return 'This endpoint has stricter rate limits' + + Use the scope parameter to isolate rate limits for different endpoints:: + + @app.route('/api/v1/search') + @ratelimit(scope='search') + def search(): + return 'Search results' + + @app.route('/api/v1/upload') + @ratelimit(scope='upload', per_ip_limit=5, window=60) + def upload(): + return 'Upload complete' + 3. The default rate limits are defined above (see comment Defaults). If you want to set different rate limits, which can be also done dynamically without restarting the application, call the set_rate_limits function:: @@ -80,6 +106,18 @@ def index(): set_rate_limits(per_token_limit, per_ip_limit, rate_limit_window) + You can also set scope-specific limits in cache:: + + from brainzutils.ratelimit import set_rate_limits + + set_rate_limits(per_token=100, per_ip=50, window=60, scope='search') + set_rate_limits(per_token=10, per_ip=5, window=120, scope='upload') + + Limit resolution order (first non-None value wins): + 1. Decorator parameters (per_token_limit, per_ip_limit, window) + 2. Scope-specific limits from cache (if scope is provided) + 3. Global limits from cache + 4. To enable token based rate limiting, callers need to pass the Authorization header (see above) and the application needs to provide a user validation function:: @@ -114,148 +152,190 @@ def __init__(self, key_prefix, limit, per): def set_user_validation_function(func): - ''' + """ The function passed to this method should accept on argument, the Authorization header contents and return a True/False value if this user is a valid user. - ''' - + """ global ratelimit_user_validation ratelimit_user_validation = func -def set_rate_limits(per_token, per_ip, window): - ''' - Update the current rate limits. This will affect all new rate limiting windows and existing windows will not be changed. - ''' - cache.set(ratelimit_per_token_key, per_token, expirein=0, namespace=ratelimit_cache_namespace) - cache.set(ratelimit_per_ip_key, per_ip, expirein=0, namespace=ratelimit_cache_namespace) - cache.set(ratelimit_window_key, window, expirein=0, namespace=ratelimit_cache_namespace) +def set_rate_limits(per_token, per_ip, window, scope=None): + """ + Update the current global rate limits. 
This will affect all new rate limiting windows + and existing windows will not be changed. If a scope is provided, the limits will be + changed only for that scope. + """ + prefix = f"{scope}:" if scope else "" + cache.set_many({ + f"{prefix}{ratelimit_per_token_key}": per_token, + f"{prefix}{ratelimit_per_ip_key}": per_ip, + f"{prefix}{ratelimit_window_key}": window, + }, expirein=0, namespace=ratelimit_cache_namespace) + + +def get_rate_limits(scope=None): + """ + Get rate limits for global or specific scope from cache. + + Args: + scope: The scope name to get limits for + + Returns: + A dict with 'per_token', 'per_ip', 'window' keys if scope limits exist, + or None if no scope-specific limits are set. + """ + prefix = f"{scope}:" if scope else "" + keys = { + "per_token": f"{prefix}{ratelimit_per_token_key}", + "per_ip": f"{prefix}{ratelimit_per_ip_key}", + "window": f"{prefix}{ratelimit_window_key}" + } + cache_values = cache.get_many(list(keys.values()), namespace=ratelimit_cache_namespace) + + result = {} + for key in keys: + if (value := cache_values.get(keys[key])) is not None: + result[key] = value + # if returning global rate limits and value is not in cache, + # return global defaults + elif scope is None: + result[key] = ratelimit_defaults[key] + else: + result[key] = None + + return result def inject_x_rate_headers(response): - ''' + """ Add rate limit headers to responses - ''' + """ limit = get_view_rate_limit() if limit: h = response.headers - h.add('Access-Control-Expose-Headers', 'X-RateLimit-Remaining,X-RateLimit-Limit,X-RateLimit-Reset,X-RateLimit-Reset-In') - h.add('X-RateLimit-Remaining', str(limit.remaining)) - h.add('X-RateLimit-Limit', str(limit.limit)) - h.add('X-RateLimit-Reset', str(limit.reset)) - h.add('X-RateLimit-Reset-In', str(limit.seconds_before_reset)) + h.add("Access-Control-Expose-Headers", "X-RateLimit-Remaining,X-RateLimit-Limit,X-RateLimit-Reset,X-RateLimit-Reset-In") + h.add("X-RateLimit-Remaining", str(limit.remaining)) + h.add("X-RateLimit-Limit", str(limit.limit)) + h.add("X-RateLimit-Reset", str(limit.reset)) + h.add("X-RateLimit-Reset-In", str(limit.seconds_before_reset)) return response def get_view_rate_limit(): - ''' + """ Helper function to fetch the ratelimit limits from the flask context - ''' - return getattr(g, '_view_rate_limit', None) + """ + return getattr(g, "_view_rate_limit", None) def on_over_limit(limit): - ''' + """ Set a nice and readable error message for over the limit requests. - ''' + """ raise TooManyRequests( - 'You have exceeded your rate limit. See the X-RateLimit-* response headers for more ' \ - 'information on your current rate limit.') - - -def check_limit_freshness(): - ''' - This function checks to see if the values we have cached in the current request context - are still fresh enough. If they've existed longer than the timeout value, refresh from - the cache. This allows us to not check the limits for each request, saving cache traffic. 
- ''' - - limits_timeout = getattr(g, '_' + ratelimit_timeout, 0) - if time.time() <= limits_timeout: - return - - value = int(cache.get(ratelimit_per_token_key, namespace=ratelimit_cache_namespace) or '0') - if not value: - cache.set(ratelimit_per_token_key, ratelimit_per_token_default, expirein=0, namespace=ratelimit_cache_namespace) - value = ratelimit_per_token_default - setattr(g, '_' + ratelimit_per_token_key, value) - - value = int(cache.get(ratelimit_per_ip_key, namespace=ratelimit_cache_namespace) or '0') - if not value: - cache.set(ratelimit_per_ip_key, ratelimit_per_ip_default, expirein=0, namespace=ratelimit_cache_namespace) - value = ratelimit_per_ip_default - setattr(g, '_' + ratelimit_per_ip_key, value) - - value = int(cache.get(ratelimit_window_key, namespace=ratelimit_cache_namespace) or '0') - if not value: - cache.set(ratelimit_window_key, ratelimit_window_default, expirein=0, namespace=ratelimit_cache_namespace) - value = ratelimit_window_default - setattr(g, '_' + ratelimit_window_key, value) - - setattr(g, '_' + ratelimit_timeout, int(time.time()) + ratelimit_refresh) - - -def get_per_ip_limits(): - ''' - Fetch the per IP limits from context/cache - ''' - check_limit_freshness() - return { - 'limit': getattr(g, '_' + ratelimit_per_ip_key), - 'window' : getattr(g, '_' + ratelimit_window_key), - } - - -def get_per_token_limits(): - ''' - Fetch the per token limits from context/cache - ''' - check_limit_freshness() - return { - 'limit': getattr(g, '_' + ratelimit_per_token_key), - 'window' : getattr(g, '_' + ratelimit_window_key), - } - - -def get_rate_limit_data(request): - '''Fetch key for the given request. If an Authorization header is provided, + "You have exceeded your rate limit. See the X-RateLimit-* response headers for more " + "information on your current rate limit." + ) + +def _get_rate_limit_helper( + limit_type: Literal["per_ip", "per_token"], + _global: dict, _scope: dict = None, _local: dict = None +) -> dict: + values = {} + if _local is not None and (local_limit := _local.get(limit_type)) is not None: + values["limit"] = local_limit + elif _scope is not None and (scope_limit := _scope.get(limit_type)) is not None: + values["limit"] = scope_limit + else: + values["limit"] = _global[limit_type] + + if _local is not None and (local_window := _local.get("window")) is not None: + values["window"] = local_window + elif _scope is not None and (scope_window := _scope.get("window")) is not None: + values["window"] = scope_window + else: + values["window"] = _global["window"] + + return values + + +def get_rate_limit_data(request, per_token_limit=None, per_ip_limit=None, window=None, scope=None): + """Fetch key for the given request. If an Authorization header is provided, the caller will get a better and personalized rate limit. If no header is provided, the caller will be rate limited by IP, which gets an overall lower rate limit. This should encourage callers to always provide the Authorization token - ''' + + Args: + request: The Flask request object + per_token_limit: Optional override for per-token limit (uses cache value if None) + per_ip_limit: Optional override for per-IP limit (uses cache value if None) + window: Optional override for rate limit window in seconds (uses cache value if None) + scope: Optional scope name to check for scope-specific limits in cache + + Limit resolution order (first non-None value wins): + 1. Decorator parameters (per_token_limit, per_ip_limit, window) + 2. Scope-specific limits from cache (if scope is provided) + 3. 
Global limits from cache + """ + _global = get_rate_limits() + _scope = get_rate_limits(scope) if scope else None + _local = { + "per_token": per_token_limit, + "per_ip": per_ip_limit, + "window": window + } # If a user verification function is provided, parse the Authorization header and try to look up that user if ratelimit_user_validation: - auth_header = request.headers.get('Authorization') + auth_header = request.headers.get("Authorization") if auth_header: auth_token = auth_header[6:] is_valid = ratelimit_user_validation(auth_token) if is_valid: - values = get_per_token_limits() - values['key'] = auth_token + values = _get_rate_limit_helper( + "per_token", _global=_global, _scope=_scope, _local=_local + ) + values["key"] = auth_token return values - # no valid auth token provided. Look for a remote addr header provided a the proxy # or if that isn't available use the IP address from the header - ip = request.environ.get('REMOTE_ADDR', None) + ip = request.environ.get("REMOTE_ADDR", None) if not ip: ip = request.remote_addr - values = get_per_ip_limits() - values['key'] = ip + values = _get_rate_limit_helper( + "per_ip", _global=_global, _scope=_scope, _local=_local + ) + values["key"] = ip return values -def ratelimit(): - ''' - This is the decorator that should be applied to all view functions that should be +def ratelimit(per_token_limit=None, per_ip_limit=None, window=None, scope=None): + """ + This is the decorator that should be applied to all view functions that should be rate limited. - ''' + + Args: + per_token_limit: Optional override for per-token limit (uses cache value if None) + per_ip_limit: Optional override for per-IP limit (uses cache value if None) + window: Optional override for rate limit window in seconds (uses cache value if None) + scope: Optional scope to isolate rate limits for different endpoints. + If provided, the rate limit key will be scoped with this value, + allowing different endpoints to have separate rate limit buckets.. + """ def decorator(f): def rate_limited(*args, **kwargs): - data = get_rate_limit_data(request) - rlimit = RateLimit(data['key'], data['limit'], data['window']) + data = get_rate_limit_data( + request, + per_token_limit=per_token_limit, + per_ip_limit=per_ip_limit, + window=window, + scope=scope + ) + key = f"{scope}:{data['key']}" if scope else data["key"] + rlimit = RateLimit(key, data["limit"], data["window"]) g._view_rate_limit = rlimit if rlimit.over_limit: return on_over_limit(rlimit)
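
Taken together, the new decorator parameters, scopes, and set_rate_limits compose as in the following minimal wiring sketch. It assumes a reachable Redis instance with the brainzutils cache already initialized (as in the tests' setUp); the endpoint paths, the "search" scope name, and the token check are illustrative only::

    from brainzutils import flask
    from brainzutils.ratelimit import (
        ratelimit,
        set_rate_limits,
        inject_x_rate_headers,
        set_user_validation_function,
    )

    app = flask.CustomFlask(__name__)

    # Expose the X-RateLimit-* headers on every response.
    @app.after_request
    def add_rate_limit_headers(response):
        return inject_x_rate_headers(response)

    # Enable per-token limits by telling the module how to validate the
    # token taken from the Authorization header (illustrative check only).
    set_user_validation_function(lambda token: token == "some-known-token")

    # Global limits: 50 requests per token, 30 per IP, per 10 second window.
    set_rate_limits(50, 30, 10)

    # Scope-wide limits stored in cache: every endpoint decorated with
    # scope="search" shares this bucket and, by default, these limits.
    set_rate_limits(per_token=100, per_ip=50, window=60, scope="search")

    @app.route("/api/v1/search")
    @ratelimit(scope="search")
    def search():
        return "search results"

    # Decorator parameters win over the scope limits in cache, which in turn
    # win over the global limits, so unauthenticated callers get only
    # 5 requests per 60 seconds here even though the scope allows 50.
    @app.route("/api/v1/search/expensive")
    @ratelimit(scope="search", per_ip_limit=5, window=60)
    def expensive_search():
        return "expensive search results"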
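
The resolution order can also be inspected at runtime through get_rate_limits. A small sketch of the expected behavior, again assuming an initialized cache; the "upload" scope name and the values are illustrative, and the exact types of values read back from cache depend on the cache backend's serialization::

    from brainzutils.ratelimit import set_rate_limits, get_rate_limits

    # With nothing stored in cache yet, the global limits fall back to the
    # module defaults (per_token=50, per_ip=30, window=10).
    print(get_rate_limits())

    # A scope with no stored limits returns None for every key, which is
    # what lets the decorator fall through to the global values.
    print(get_rate_limits("upload"))
    # {'per_token': None, 'per_ip': None, 'window': None}

    # Store scope limits; requests to endpoints decorated with
    # scope="upload" resolve against these from now on, although counters
    # in already-open windows are not reset.
    set_rate_limits(per_token=10, per_ip=5, window=120, scope="upload")
    print(get_rate_limits("upload"))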