weight:
  title: Weight
  description: >
    The weight of a request matching this matcher when acquiring a call from the rate limiter.
    Different endpoints can consume different amounts from a shared budget by specifying
    different weights. Must be at least 1. A string value is interpolated against the
    connector config and must evaluate to an integer. If not set, each request counts as 1.
  anyOf:
    # Mirrors the runtime check (weight >= 1) applied when the matcher is built,
    # so an invalid literal is rejected at schema validation time.
    - type: integer
      minimum: 1
    - type: string
def create_http_request_matcher(
    self, model: HttpRequestRegexMatcherModel, config: Config, **kwargs: Any
) -> HttpRequestRegexMatcher:
    """Build an ``HttpRequestRegexMatcher`` from its declarative model.

    ``model.weight`` may be an integer or a string. A string is evaluated as an
    interpolated expression against ``config`` before being converted to ``int``.

    :param model: the parsed matcher definition.
    :param config: connector configuration used to evaluate a string weight.
    :param kwargs: accepted for signature compatibility; not read here.
    :return: a matcher carrying the resolved weight (``None`` when no weight is configured).
    :raises ValueError: if the weight cannot be converted to an integer or is lower than 1.
    """
    weight = model.weight
    if weight is not None:
        configured_weight = weight
        if isinstance(weight, str):
            weight = InterpolatedString.create(weight, parameters={}).eval(config)
        try:
            weight = int(weight)
        except (TypeError, ValueError) as exc:
            # A bare int() failure does not say which manifest field was wrong;
            # surface both the evaluated value and the configured expression.
            raise ValueError(
                f"weight must evaluate to an integer, got {weight!r} "
                f"(configured as {configured_weight!r})"
            ) from exc
        if weight < 1:
            raise ValueError(f"weight must be >= 1, got {weight}")
    return HttpRequestRegexMatcher(
        method=model.method,
        url_base=model.url_base,
        url_path_pattern=model.url_path_pattern,
        params=model.params,
        headers=model.headers,
        weight=weight,
    )
def get_weight(self, request: Any) -> int:
    """Return how much of the budget ``request`` consumes under this policy.

    Only the first matcher that accepts the request is consulted. Its weight is
    used when it is an ``HttpRequestRegexMatcher`` with a weight configured;
    in every other case (no matcher accepts the request, a different matcher
    type, or no weight set) the result is 1.

    :param request: a request object
    :return: the weight for this request
    """
    first_match = next(
        (candidate for candidate in self._matchers if candidate(request)),
        None,
    )
    if first_match is None:
        return 1
    if not isinstance(first_match, HttpRequestRegexMatcher):
        return 1
    configured = first_match.weight
    return 1 if configured is None else configured
class TestWeightBasedRateLimiting:
    """Per-matcher request weights: storage and validation on the matcher,
    weight lookup on the policy, and budget consumption through ``APIBudget``."""

    def test_matcher_weight_default_none(self):
        """A matcher built without ``weight`` reports ``None``."""
        unweighted = HttpRequestRegexMatcher(url_path_pattern=r"/api/test")
        assert unweighted.weight is None

    def test_matcher_weight_is_stored(self):
        """The weight passed to the constructor is exposed unchanged."""
        weighted = HttpRequestRegexMatcher(weight=60, url_path_pattern=r"/api/test")
        assert weighted.weight == 60

    def test_matcher_rejects_zero_weight(self):
        """A weight of 0 is refused with ``ValueError``."""
        with pytest.raises(ValueError, match="weight must be >= 1"):
            HttpRequestRegexMatcher(weight=0, url_path_pattern=r"/api/test")

    def test_matcher_rejects_negative_weight(self):
        """A negative weight is refused with ``ValueError``."""
        with pytest.raises(ValueError, match="weight must be >= 1"):
            HttpRequestRegexMatcher(weight=-5, url_path_pattern=r"/api/test")

    def test_policy_get_weight_returns_matcher_weight(self):
        """``get_weight`` reports the weight of the matcher that accepts the request."""
        expensive_matcher = HttpRequestRegexMatcher(url_path_pattern=r"/api/expensive", weight=120)
        rate_policy = MovingWindowCallRatePolicy(
            rates=[Rate(1000, timedelta(hours=1))],
            matchers=[expensive_matcher],
        )
        request = Request("GET", "https://example.com/api/expensive")
        assert rate_policy.get_weight(request) == 120

    def test_policy_get_weight_defaults_to_1(self):
        """``get_weight`` falls back to 1 when the accepting matcher has no weight."""
        plain_matcher = HttpRequestRegexMatcher(url_path_pattern=r"/api/default")
        rate_policy = MovingWindowCallRatePolicy(
            rates=[Rate(1000, timedelta(hours=1))],
            matchers=[plain_matcher],
        )
        request = Request("GET", "https://example.com/api/default")
        assert rate_policy.get_weight(request) == 1

    def test_policy_get_weight_no_matching_matcher(self):
        """``get_weight`` falls back to 1 when no matcher accepts the request."""
        other_matcher = HttpRequestRegexMatcher(url_path_pattern=r"/api/other", weight=50)
        rate_policy = MovingWindowCallRatePolicy(
            rates=[Rate(1000, timedelta(hours=1))],
            matchers=[other_matcher],
        )
        request = Request("GET", "https://example.com/api/unmatched")
        assert rate_policy.get_weight(request) == 1

    def test_api_budget_uses_weight(self):
        """``APIBudget`` charges the matcher's weight for every acquired call."""
        heavy_policy = MovingWindowCallRatePolicy(
            rates=[Rate(100, timedelta(hours=1))],
            matchers=[HttpRequestRegexMatcher(url_path_pattern=r"/api/heavy", weight=10)],
        )
        api_budget = APIBudget(policies=[heavy_policy])

        # Ten calls at weight 10 use up the whole allowance of 100.
        for _ in range(10):
            api_budget.acquire_call(Request("GET", "https://example.com/api/heavy"), block=False)

        # An eleventh call would bring the total to 110, above the limit.
        with pytest.raises(CallRateLimitHit):
            api_budget.acquire_call(Request("GET", "https://example.com/api/heavy"), block=False)

    def test_weight_1_backward_compatible(self):
        """Without a configured weight each call costs 1, as before weights existed."""
        normal_policy = MovingWindowCallRatePolicy(
            rates=[Rate(5, timedelta(hours=1))],
            matchers=[HttpRequestRegexMatcher(url_path_pattern=r"/api/normal")],
        )
        api_budget = APIBudget(policies=[normal_policy])

        for _ in range(5):
            api_budget.acquire_call(Request("GET", "https://example.com/api/normal"), block=False)

        with pytest.raises(CallRateLimitHit):
            api_budget.acquire_call(Request("GET", "https://example.com/api/normal"), block=False)

    def test_shared_budget_different_weights(self):
        """Matchers with different weights on one policy draw from the same budget."""
        cheap_matcher = HttpRequestRegexMatcher(url_path_pattern=r"/api/cheap", weight=1)
        expensive_matcher = HttpRequestRegexMatcher(url_path_pattern=r"/api/expensive", weight=10)
        shared_policy = MovingWindowCallRatePolicy(
            rates=[Rate(20, timedelta(hours=1))],
            matchers=[cheap_matcher, expensive_matcher],
        )
        api_budget = APIBudget(policies=[shared_policy])

        # One expensive call (10) followed by ten cheap calls (1 each) totals 20.
        api_budget.acquire_call(Request("GET", "https://example.com/api/expensive"), block=False)
        for _ in range(10):
            api_budget.acquire_call(Request("GET", "https://example.com/api/cheap"), block=False)

        # The budget is exhausted at 20/20, so even a cheap call is refused.
        with pytest.raises(CallRateLimitHit):
            api_budget.acquire_call(Request("GET", "https://example.com/api/cheap"), block=False)

    def test_moving_window_rejects_weight_exceeding_limit(self):
        """``try_acquire`` raises ``ValueError`` for a weight above the lowest rate limit."""
        heavy_policy = MovingWindowCallRatePolicy(
            rates=[Rate(10, timedelta(hours=1)), Rate(100, timedelta(days=1))],
            matchers=[HttpRequestRegexMatcher(url_path_pattern=r"/api/heavy", weight=50)],
        )
        request = Request("GET", "https://example.com/api/heavy")
        with pytest.raises(
            ValueError, match="Weight can not exceed the lowest configured rate limit"
        ):
            heavy_policy.try_acquire(request, weight=50)