Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1832,6 +1832,15 @@ definitions:
description: The headers to match.
type: object
additionalProperties: true
weight:
title: Weight
description: >
The weight of a request matching this matcher when acquiring a call from the rate limiter.
Different endpoints can consume different amounts from a shared budget by specifying
different weights. If not set, each request counts as 1.
anyOf:
- type: integer
- type: string
additionalProperties: true
DefaultErrorHandler:
title: Default Error Handler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,11 @@ class Config:
headers: Optional[Dict[str, Any]] = Field(
None, description="The headers to match.", title="Headers"
)
weight: Optional[Union[int, str]] = Field(
None,
description="The weight of a request matching this matcher when acquiring a call from the rate limiter. Different endpoints can consume different amounts from a shared budget by specifying different weights. If not set, each request counts as 1.",
title="Weight",
)


class DpathExtractor(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4387,12 +4387,21 @@ def create_rate(self, model: RateModel, config: Config, **kwargs: Any) -> Rate:
def create_http_request_matcher(
    self, model: HttpRequestRegexMatcherModel, config: Config, **kwargs: Any
) -> HttpRequestRegexMatcher:
    """Build an ``HttpRequestRegexMatcher`` from its declarative model.

    ``model.weight`` may be an integer or a string. A string is evaluated as an
    interpolated expression against ``config`` and the result is converted to ``int``.

    :param model: the parsed matcher definition
    :param config: connector configuration used when evaluating an interpolated weight
    :param kwargs: accepted but not used by this method
    :return: the matcher configured with the model's criteria and resolved weight
    :raises ValueError: if the weight cannot be converted to an integer or is below 1
    """
    weight = model.weight
    if weight is not None:
        if isinstance(weight, str):
            evaluated = InterpolatedString.create(weight, parameters={}).eval(config)
        else:
            evaluated = weight
        try:
            weight = int(evaluated)
        except (TypeError, ValueError) as exc:
            # Surface a message that names the offending field instead of the bare
            # "invalid literal for int()" / "int() argument must be ..." errors.
            raise ValueError(
                f"weight must be an integer >= 1, got {evaluated!r} (from {model.weight!r})"
            ) from exc
        if weight < 1:
            raise ValueError(f"weight must be >= 1, got {weight}")
    return HttpRequestRegexMatcher(
        method=model.method,
        url_base=model.url_base,
        url_path_pattern=model.url_path_pattern,
        params=model.params,
        headers=model.headers,
        weight=weight,
    )

def set_api_budget(self, component_definition: ComponentDefinition, config: Config) -> None:
Expand Down
39 changes: 37 additions & 2 deletions airbyte_cdk/sources/streams/call_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,22 @@ def __init__(
url_path_pattern: Optional[str] = None,
params: Optional[Mapping[str, Any]] = None,
headers: Optional[Mapping[str, Any]] = None,
weight: Optional[int] = None,
):
"""
:param method: HTTP method (e.g. "GET", "POST"); compared case-insensitively.
:param url_base: Base URL (scheme://host) that must match.
:param url_path_pattern: A regex pattern that will be applied to the path portion of the URL.
:param params: Dictionary of query parameters that must be present in the request.
:param headers: Dictionary of headers that must be present (header keys are compared case-insensitively).
:param weight: The weight of a request matching this matcher. If set, this value is used
when acquiring a call from the rate limiter, enabling cost-based rate limiting
where different endpoints consume different amounts from a shared budget.
If not set, each request counts as 1.
"""
if weight is not None and weight < 1:
raise ValueError(f"weight must be >= 1, got {weight}")
self._weight = weight
self._method = method.upper() if method else None

# Normalize the url_base if provided: remove trailing slash.
Expand Down Expand Up @@ -242,11 +250,16 @@ def __call__(self, request: Any) -> bool:

return True

@property
def weight(self) -> Optional[int]:
    """Return the weight configured for requests matching this matcher.

    :return: the stored weight, or None when no weight was configured
    """
    return self._weight

def __str__(self) -> str:
    """Return a readable description of the matcher's criteria and weight."""
    # The compiled path pattern is optional; show its source text when present.
    regex = self._url_path_pattern.pattern if self._url_path_pattern else None
    return (
        f"HttpRequestRegexMatcher(method={self._method}, url_base={self._url_base}, "
        f"url_path_pattern={regex}, params={self._params}, headers={self._headers}, "
        f"weight={self._weight})"
    )


Expand All @@ -265,6 +278,22 @@ def matches(self, request: Any) -> bool:
return True
return any(matcher(request) for matcher in self._matchers)

def get_weight(self, request: Any) -> int:
    """Return the budget cost of ``request`` according to this policy's matchers.

    Only the first matcher that accepts the request is consulted. Its weight is
    used when it is an ``HttpRequestRegexMatcher`` with a weight configured; in
    every other case (no matcher accepts the request, the matcher is of another
    type, or no weight was set) the cost is 1.

    :param request: a request object
    :return: the weight for this request
    """
    first_match = next((candidate for candidate in self._matchers if candidate(request)), None)
    if first_match is None:
        return 1
    if not isinstance(first_match, HttpRequestRegexMatcher):
        return 1
    configured = first_match.weight
    return configured if configured is not None else 1


class UnlimitedCallRatePolicy(BaseCallRatePolicy):
"""
Expand Down Expand Up @@ -420,6 +449,11 @@ def __init__(self, rates: list[Rate], matchers: list[RequestMatcher]):
def try_acquire(self, request: Any, weight: int) -> None:
if not self.matches(request):
raise ValueError("Request does not match the policy")
lowest_limit = min(rate.limit for rate in self._bucket.rates)
if weight > lowest_limit:
raise ValueError(
f"Weight can not exceed the lowest configured rate limit ({lowest_limit})"
)

try:
self._limiter.try_acquire(request, weight=weight)
Expand Down Expand Up @@ -596,7 +630,8 @@ def _do_acquire(
# sometimes we spend all budget before a second attempt, so we have a few more attempts
for attempt in range(1, self._maximum_attempts_to_acquire):
try:
policy.try_acquire(request, weight=1)
weight = policy.get_weight(request) if isinstance(policy, BaseCallRatePolicy) else 1
policy.try_acquire(request, weight=weight)
return
except CallRateLimitHit as exc:
last_exception = exc
Expand Down
117 changes: 117 additions & 0 deletions unit_tests/sources/streams/test_call_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,13 @@ def test_http_request_matching(mocker):
users_policy.matches.side_effect = HttpRequestMatcher(
url="http://domain/api/users", method="GET"
)
users_policy.get_weight.return_value = 1
groups_policy.matches.side_effect = HttpRequestMatcher(
url="http://domain/api/groups", method="POST"
)
groups_policy.get_weight.return_value = 1
root_policy.matches.side_effect = HttpRequestMatcher(method="GET")
root_policy.get_weight.return_value = 1
api_budget = APIBudget(
policies=[
users_policy,
Expand Down Expand Up @@ -360,6 +363,120 @@ def test_with_cache(self, mocker, requests_mock):
assert MovingWindowCallRatePolicy.try_acquire.call_count == 1


class TestWeightBasedRateLimiting:
    """Tests for weight-based rate limiting, where endpoints consume different amounts of a shared budget."""

    def test_matcher_weight_default_none(self):
        """HttpRequestRegexMatcher weight defaults to None when not specified."""
        matcher = HttpRequestRegexMatcher(url_path_pattern=r"/api/test")
        assert matcher.weight is None

    def test_matcher_weight_is_stored(self):
        """HttpRequestRegexMatcher stores the weight value when provided."""
        matcher = HttpRequestRegexMatcher(url_path_pattern=r"/api/test", weight=60)
        assert matcher.weight == 60

    def test_matcher_rejects_zero_weight(self):
        """HttpRequestRegexMatcher raises ValueError for weight=0."""
        with pytest.raises(ValueError, match="weight must be >= 1"):
            HttpRequestRegexMatcher(url_path_pattern=r"/api/test", weight=0)

    def test_matcher_rejects_negative_weight(self):
        """HttpRequestRegexMatcher raises ValueError for negative weight."""
        with pytest.raises(ValueError, match="weight must be >= 1"):
            HttpRequestRegexMatcher(url_path_pattern=r"/api/test", weight=-5)

    def test_policy_get_weight_returns_matcher_weight(self):
        """BaseCallRatePolicy.get_weight returns weight from the matching matcher."""
        policy = MovingWindowCallRatePolicy(
            matchers=[HttpRequestRegexMatcher(url_path_pattern=r"/api/expensive", weight=120)],
            rates=[Rate(1000, timedelta(hours=1))],
        )
        req = Request("GET", "https://example.com/api/expensive")
        assert policy.get_weight(req) == 120

    def test_policy_get_weight_defaults_to_1(self):
        """BaseCallRatePolicy.get_weight returns 1 when no matcher has a weight set."""
        policy = MovingWindowCallRatePolicy(
            matchers=[HttpRequestRegexMatcher(url_path_pattern=r"/api/default")],
            rates=[Rate(1000, timedelta(hours=1))],
        )
        req = Request("GET", "https://example.com/api/default")
        assert policy.get_weight(req) == 1

    def test_policy_get_weight_no_matching_matcher(self):
        """BaseCallRatePolicy.get_weight returns 1 when no matcher matches the request."""
        policy = MovingWindowCallRatePolicy(
            matchers=[HttpRequestRegexMatcher(url_path_pattern=r"/api/other", weight=50)],
            rates=[Rate(1000, timedelta(hours=1))],
        )
        req = Request("GET", "https://example.com/api/unmatched")
        assert policy.get_weight(req) == 1

    def test_api_budget_uses_weight(self):
        """APIBudget._do_acquire passes the matcher's weight to try_acquire."""
        policy = MovingWindowCallRatePolicy(
            matchers=[HttpRequestRegexMatcher(url_path_pattern=r"/api/heavy", weight=10)],
            rates=[Rate(100, timedelta(hours=1))],
        )
        budget = APIBudget(policies=[policy])

        # Make requests — each weighs 10 from the budget of 100
        for _ in range(10):
            budget.acquire_call(Request("GET", "https://example.com/api/heavy"), block=False)

        # The 11th request should exceed the budget (10 * 10 = 100, one more = 110 > 100)
        with pytest.raises(CallRateLimitHit):
            budget.acquire_call(Request("GET", "https://example.com/api/heavy"), block=False)

    def test_weight_1_backward_compatible(self):
        """When weight is not set, behavior is identical to the old hardcoded weight=1."""
        policy = MovingWindowCallRatePolicy(
            matchers=[HttpRequestRegexMatcher(url_path_pattern=r"/api/normal")],
            rates=[Rate(5, timedelta(hours=1))],
        )
        budget = APIBudget(policies=[policy])

        for _ in range(5):
            budget.acquire_call(Request("GET", "https://example.com/api/normal"), block=False)

        with pytest.raises(CallRateLimitHit):
            budget.acquire_call(Request("GET", "https://example.com/api/normal"), block=False)

    def test_shared_budget_different_weights(self):
        """Multiple matchers with different weights sharing one policy correctly consume the shared budget."""
        # Shared policy matches both endpoints via regex
        policy = MovingWindowCallRatePolicy(
            matchers=[
                HttpRequestRegexMatcher(url_path_pattern=r"/api/cheap", weight=1),
                HttpRequestRegexMatcher(url_path_pattern=r"/api/expensive", weight=10),
            ],
            rates=[Rate(20, timedelta(hours=1))],
        )
        budget = APIBudget(policies=[policy])

        # Make 1 expensive request (weight 10) and 10 cheap requests (weight 1 each) = total 20
        budget.acquire_call(Request("GET", "https://example.com/api/expensive"), block=False)
        for _ in range(10):
            budget.acquire_call(Request("GET", "https://example.com/api/cheap"), block=False)

        # Budget is now at 20/20 — any further request should fail
        with pytest.raises(CallRateLimitHit):
            budget.acquire_call(Request("GET", "https://example.com/api/cheap"), block=False)

    def test_moving_window_rejects_weight_exceeding_limit(self):
        """MovingWindowCallRatePolicy raises ValueError when weight exceeds the lowest configured rate limit."""
        policy = MovingWindowCallRatePolicy(
            matchers=[HttpRequestRegexMatcher(url_path_pattern=r"/api/heavy", weight=50)],
            rates=[Rate(10, timedelta(hours=1)), Rate(100, timedelta(days=1))],
        )
        req = Request("GET", "https://example.com/api/heavy")
        with pytest.raises(
            ValueError, match="Weight can not exceed the lowest configured rate limit"
        ):
            policy.try_acquire(req, weight=50)


class TestHttpRequestRegexMatcher:
"""
Tests for the new regex-based logic:
Expand Down
Loading