Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 11 additions & 103 deletions brainzutils/flask/test/test_ratelimit.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,80 +112,18 @@ def index():
set_rate_limits(self.max_token_requests, self.max_ip_requests, self.ratelimit_window)
self.make_requests(client, self.max_token_requests, token="Token %s" % valid_user)

def test_custom_ip_limit(self):
"""Test that per_ip_limit parameter overrides global limit."""
custom_limit = 2

@self.app.route("/custom")
@ratelimit(per_ip_limit=custom_limit, window=60)
def custom_endpoint():
return "OK"

client = self.app.test_client()

# First request should succeed
response = client.get("/custom")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.headers["X-RateLimit-Limit"], str(custom_limit))
self.assertEqual(response.headers["X-RateLimit-Remaining"], "1")

response = client.get("/custom")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.headers["X-RateLimit-Remaining"], "0")

response = client.get("/custom")
self.assertEqual(response.status_code, 429)

def test_custom_window(self):
"""Test that window parameter works correctly."""
@self.app.route("/short-window")
@ratelimit(per_ip_limit=1, window=2)
def short_window_endpoint():
return "OK"

client = self.app.test_client()

response = client.get("/short-window")
self.assertEqual(response.status_code, 200)
response = client.get("/short-window")
self.assertEqual(response.status_code, 429)

sleep(2.5)
response = client.get("/short-window")
self.assertEqual(response.status_code, 200)

def test_headers_contain_correct_values(self):
"""Test that rate limit headers contain expected values."""
limit = 5
window = 30

@self.app.route("/headers")
@ratelimit(per_ip_limit=limit, window=window)
def headers_endpoint():
return "OK"

client = self.app.test_client()
response = client.get("/headers")

self.assertEqual(response.status_code, 200)
self.assertIn("X-RateLimit-Limit", response.headers)
self.assertIn("X-RateLimit-Remaining", response.headers)
self.assertIn("X-RateLimit-Reset-In", response.headers)

self.assertEqual(response.headers["X-RateLimit-Limit"], str(limit))
self.assertEqual(response.headers["X-RateLimit-Remaining"], str(limit - 1))
self.assertLessEqual(int(response.headers["X-RateLimit-Reset-In"]), window)
self.assertIn("X-RateLimit-Reset", response.headers)

def test_scope_isolation(self):
"""Test that different scopes have independent rate limit buckets."""
set_rate_limits(per_token=100, per_ip=2, window=60, scope="scope_a")
set_rate_limits(per_token=100, per_ip=2, window=60, scope="scope_b")

@self.app.route("/scope-a")
@ratelimit(scope="scope_a", per_ip_limit=2, window=60)
@ratelimit(scope="scope_a")
def scope_a_endpoint():
return "A"

@self.app.route("/scope-b")
@ratelimit(scope="scope_b", per_ip_limit=2, window=60)
@ratelimit(scope="scope_b")
def scope_b_endpoint():
return "B"

Expand Down Expand Up @@ -245,13 +183,16 @@ def shared_2_endpoint():

def test_no_scope_vs_scoped(self):
"""Test that unscoped and scoped endpoints have separate buckets."""
set_rate_limits(per_token=100, per_ip=1, window=60)
set_rate_limits(per_token=100, per_ip=1, window=60, scope="my_scope")

@self.app.route("/unscoped")
@ratelimit(per_ip_limit=1, window=60)
@ratelimit()
def unscoped_endpoint():
return "Unscoped"

@self.app.route("/scoped")
@ratelimit(scope="my_scope", per_ip_limit=1, window=60)
@ratelimit(scope="my_scope")
def scoped_endpoint():
return "Scoped"

Expand Down Expand Up @@ -290,31 +231,7 @@ def test_set_and_get_scope_limits(self):
self.assertIsNotNone(result)
self.assertEqual(result["per_token"], None)
self.assertEqual(result["per_ip"], None)
self.assertEqual(result["window"], None
)

def test_decorator_overrides_cache_scope_limits(self):
"""Test that decorator parameters override scope limits from cache."""
scope = "override_scope"
set_rate_limits(per_token=100, per_ip=10, window=60, scope=scope)

@self.app.route("/override")
@ratelimit(scope=scope, per_ip_limit=2)
def override_endpoint():
return "OK"

client = self.app.test_client()

response = client.get("/override")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.headers["X-RateLimit-Limit"], "2")

response = client.get("/override")
self.assertEqual(response.status_code, 200)

# 3rd request should fail (limit is 2, not 10)
response = client.get("/override")
self.assertEqual(response.status_code, 429)
self.assertEqual(result["window"], None)

def test_scope_cache_values_stored_correctly(self):
"""Test that scope limits are stored in cache with correct keys."""
Expand Down Expand Up @@ -359,11 +276,6 @@ def global_endpoint():
def scope_priority_endpoint():
return "OK"

@self.app.route("/decorator-priority")
@ratelimit(scope=scope, per_ip_limit=2)
def decorator_priority_endpoint():
return "OK"

client = self.app.test_client()

response = client.get("/global")
Expand All @@ -373,7 +285,3 @@ def decorator_priority_endpoint():
response = client.get("/scope-priority")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.headers["X-RateLimit-Limit"], "3")

response = client.get("/decorator-priority")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.headers["X-RateLimit-Limit"], "2")
53 changes: 13 additions & 40 deletions brainzutils/ratelimit.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,23 +78,15 @@ def after_request_callbacks(response):
def index():
return '<html><body>test</body></html>'

You can also pass custom rate limit parameters directly to the decorator to override
the global/cached values::

@app.route('/expensive')
@ratelimit(per_token_limit=10, per_ip_limit=5, window=60)
def expensive_endpoint():
return 'This endpoint has stricter rate limits'

Use the scope parameter to isolate rate limits for different endpoints::
Use the scope parameter to isolate and apply custom rate limits for different endpoints::

@app.route('/api/v1/search')
@ratelimit(scope='search')
def search():
return 'Search results'

@app.route('/api/v1/upload')
@ratelimit(scope='upload', per_ip_limit=5, window=60)
@ratelimit(scope='upload')
def upload():
return 'Upload complete'

Expand All @@ -114,9 +106,8 @@ def upload():
set_rate_limits(per_token=10, per_ip=5, window=120, scope='upload')

Limit resolution order (first non-None value wins):
1. Decorator parameters (per_token_limit, per_ip_limit, window)
2. Scope-specific limits from cache (if scope is provided)
3. Global limits from cache
1. Scope-specific limits from cache (if scope is provided)
2. Global limits from cache

4. To enable token based rate limiting, callers need to pass the Authorization header (see above)
and the application needs to provide a user validation function::
Expand Down Expand Up @@ -240,51 +231,39 @@ def on_over_limit(limit):

def _get_rate_limit_helper(
limit_type: Literal["per_ip", "per_token"],
_global: dict, _scope: dict = None, _local: dict = None
_global: dict,
_scope: dict = None
) -> dict:
values = {}
if _local is not None and (local_limit := _local.get(limit_type)) is not None:
values["limit"] = local_limit
elif _scope is not None and (scope_limit := _scope.get(limit_type)) is not None:
if _scope is not None and (scope_limit := _scope.get(limit_type)) is not None:
values["limit"] = scope_limit
else:
values["limit"] = _global[limit_type]

if _local is not None and (local_window := _local.get("window")) is not None:
values["window"] = local_window
elif _scope is not None and (scope_window := _scope.get("window")) is not None:
if _scope is not None and (scope_window := _scope.get("window")) is not None:
values["window"] = scope_window
else:
values["window"] = _global["window"]

return values


def get_rate_limit_data(request, per_token_limit=None, per_ip_limit=None, window=None, scope=None):
def get_rate_limit_data(request, scope=None):
"""Fetch key for the given request. If an Authorization header is provided,
the caller will get a better and personalized rate limit. If no header is provided,
the caller will be rate limited by IP, which gets an overall lower rate limit.
This should encourage callers to always provide the Authorization token

Args:
request: The Flask request object
per_token_limit: Optional override for per-token limit (uses cache value if None)
per_ip_limit: Optional override for per-IP limit (uses cache value if None)
window: Optional override for rate limit window in seconds (uses cache value if None)
scope: Optional scope name to check for scope-specific limits in cache

Limit resolution order (first non-None value wins):
1. Decorator parameters (per_token_limit, per_ip_limit, window)
2. Scope-specific limits from cache (if scope is provided)
1. Scope-specific limits from cache (if scope is provided)
3. Global limits from cache
"""
_global = get_rate_limits()
_scope = get_rate_limits(scope) if scope else None
_local = {
"per_token": per_token_limit,
"per_ip": per_ip_limit,
"window": window
}

# If a user verification function is provided, parse the Authorization header and try to look up that user
if ratelimit_user_validation:
Expand All @@ -294,7 +273,7 @@ def get_rate_limit_data(request, per_token_limit=None, per_ip_limit=None, window
is_valid = ratelimit_user_validation(auth_token)
if is_valid:
values = _get_rate_limit_helper(
"per_token", _global=_global, _scope=_scope, _local=_local
"per_token", _global=_global, _scope=_scope
)
values["key"] = auth_token
return values
Expand All @@ -306,21 +285,18 @@ def get_rate_limit_data(request, per_token_limit=None, per_ip_limit=None, window
ip = request.remote_addr

values = _get_rate_limit_helper(
"per_ip", _global=_global, _scope=_scope, _local=_local
"per_ip", _global=_global, _scope=_scope
)
values["key"] = ip
return values


def ratelimit(per_token_limit=None, per_ip_limit=None, window=None, scope=None):
def ratelimit(scope=None):
"""
This is the decorator that should be applied to all view functions that should be
rate limited.

Args:
per_token_limit: Optional override for per-token limit (uses cache value if None)
per_ip_limit: Optional override for per-IP limit (uses cache value if None)
window: Optional override for rate limit window in seconds (uses cache value if None)
scope: Optional scope to isolate rate limits for different endpoints.
If provided, the rate limit key will be scoped with this value,
allowing different endpoints to have separate rate limit buckets..
Expand All @@ -329,9 +305,6 @@ def decorator(f):
def rate_limited(*args, **kwargs):
data = get_rate_limit_data(
request,
per_token_limit=per_token_limit,
per_ip_limit=per_ip_limit,
window=window,
scope=scope
)
key = f"{scope}:{data['key']}" if scope else data["key"]
Expand Down