diff --git a/descope/descope_client.py b/descope/descope_client.py index ace79755c..609bf1824 100644 --- a/descope/descope_client.py +++ b/descope/descope_client.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import os import warnings from typing import Iterable @@ -19,8 +20,13 @@ from descope.common import DEFAULT_TIMEOUT_SECONDS, AccessKeyLoginOptions, EndpointsV1 from descope.exceptions import ERROR_TYPE_INVALID_ARGUMENT, AuthException from descope.http_client import HTTPClient +from descope.management.common import MgmtV1 from descope.mgmt import MGMT # noqa: F401 +logger = logging.getLogger(__name__) + +LICENSE_HANDSHAKE_TIMEOUT_SECONDS = 5.0 + class DescopeClient: ALGORITHM_KEY = "alg" @@ -106,6 +112,38 @@ def __init__( self._auth_http_client = auth_http_client self._mgmt_http_client = mgmt_http_client + # Synchronous license handshake so the first management request after + # construction can carry the x-descope-license header. Backend skips + # license-header validation for the GetLicense endpoint itself, so the + # initial request is safe before the tier is cached. + if mgmt_http_client.management_key: + self._fetch_rate_limit_tier() + + def _fetch_rate_limit_tier(self) -> None: + try: + response = httpx.get( + f"{self._mgmt_http_client.base_url}{MgmtV1.license_get_path}", + headers={ + "Authorization": ( + f"Bearer {self._mgmt_http_client.project_id}:{self._mgmt_http_client.management_key}" + ) + }, + follow_redirects=True, + verify=self._mgmt_http_client.client_verify, + timeout=LICENSE_HANDSHAKE_TIMEOUT_SECONDS, + ) + if not response.is_success: + logger.warning( + "License handshake returned non-success status %s", + response.status_code, + ) + return + tier = response.json().get("rateLimitTier") + if tier: + self._mgmt_http_client.rate_limit_tier = tier + except Exception as e: + logger.warning("License handshake failed: %s", e) + @property def mgmt(self): return self._mgmt diff --git a/descope/http_client.py b/descope/http_client.py index b9665a810..7562db188 100644 --- a/descope/http_client.py +++ b/descope/http_client.py @@ -174,6 +174,10 @@ def __init__( self.management_key = management_key self.verbose = verbose self._thread_local = threading.local() + # Populated by the license handshake when a management key is configured. + # Sent in the x-descope-license header so Cloudflare can apply the right + # rate limit bucket per customer tier. + self.rate_limit_tier: str | None = None # Setup SSL verification for httpx (backwards compatibility with requests) self.client_verify: bool | ssl.SSLContext = False @@ -400,4 +404,6 @@ def _get_default_headers(self, pswd: str | None = None): if self.management_key: bearer = f"{bearer}:{self.management_key}" headers["Authorization"] = f"Bearer {bearer}" + if self.rate_limit_tier: + headers["x-descope-license"] = self.rate_limit_tier return headers diff --git a/descope/management/common.py b/descope/management/common.py index eb3af4c91..266c5d317 100644 --- a/descope/management/common.py +++ b/descope/management/common.py @@ -282,6 +282,9 @@ class MgmtV1: mgmt_key_delete_path = "/v1/mgmt/managementkey/delete" mgmt_key_search_path = "/v1/mgmt/managementkey/search" + # license + license_get_path = "/v1/mgmt/license" + class MgmtSignUpOptions: def __init__( diff --git a/descope/management/license.py b/descope/management/license.py new file mode 100644 index 000000000..1ae34bb2f --- /dev/null +++ b/descope/management/license.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from descope._http_base import HTTPBase +from descope.management.common import MgmtV1 + + +class License(HTTPBase): + def get(self) -> dict: + """ + Fetch the rate limit tier for the project's company license. + + Returns a dict with a ``rateLimitTier`` field whose value is one of + ``tier1`` (free), ``tier2`` (pro), ``tier3`` (growth), or ``tier4`` + (enterprise). The SDK sends this value in the ``x-descope-license`` + header on every management request so Cloudflare can apply the right + rate limit bucket. + """ + response = self._http.get(MgmtV1.license_get_path) + return response.json() diff --git a/descope/mgmt.py b/descope/mgmt.py index 3ff1fd7f2..a992c7ebb 100644 --- a/descope/mgmt.py +++ b/descope/mgmt.py @@ -11,6 +11,7 @@ from descope.management.flow import Flow from descope.management.group import Group from descope.management.jwt import JWT +from descope.management.license import License from descope.management.management_key import ManagementKey from descope.management.outbound_application import ( OutboundApplication, @@ -45,6 +46,7 @@ def __init__(self, http_client: HTTPClient, auth: Auth, fga_cache_url: Optional[ self._flow = Flow(http_client) self._group = Group(http_client) self._jwt = JWT(http_client, auth=auth) + self._license = License(http_client) self._management_key = ManagementKey(http_client) self._outbound_application = OutboundApplication(http_client) self._outbound_application_by_token = OutboundApplicationByToken(http_client) @@ -94,6 +96,11 @@ def jwt(self): self._ensure_management_key("jwt") return self._jwt + @property + def license(self): + self._ensure_management_key("license") + return self._license + @property def permission(self): self._ensure_management_key("permission") diff --git a/tests/management/test_license.py b/tests/management/test_license.py new file mode 100644 index 000000000..75c78f743 --- /dev/null +++ b/tests/management/test_license.py @@ -0,0 +1,85 @@ +from unittest import mock +from unittest.mock import patch + +from descope import AuthException, DescopeClient +from descope.common import DEFAULT_TIMEOUT_SECONDS +from descope.management.common import MgmtV1 + +from .. import common +from ..testutils import SSLMatcher + + +class TestLicense(common.DescopeTest): + def setUp(self) -> None: + super().setUp() + self.dummy_project_id = "dummy" + self.dummy_management_key = "key" + self.public_key_dict = { + "alg": "ES384", + "crv": "P-384", + "kid": "P2CtzUhdqpIF2ys9gg7ms06UvtC4", + "kty": "EC", + "use": "sig", + "x": "pX1l7nT2turcK5_Cdzos8SKIhpLh1Wy9jmKAVyMFiOCURoj-WQX1J0OUQqMsQO0s", + "y": "B0_nWAv2pmG_PzoH3-bSYZZzLNKUA0RoE2SH7DaS0KV4rtfWZhYd0MEr0xfdGKx0", + } + + def test_get_failure(self): + client = DescopeClient( + self.dummy_project_id, + self.public_key_dict, + False, + self.dummy_management_key, + ) + with patch("httpx.get") as mock_get: + mock_get.return_value.is_success = False + self.assertRaises(AuthException, client.mgmt.license.get) + + def test_get_success(self): + client = DescopeClient( + self.dummy_project_id, + self.public_key_dict, + False, + self.dummy_management_key, + ) + with patch("httpx.get") as mock_get: + network_resp = mock.Mock() + network_resp.is_success = True + network_resp.json.return_value = {"rateLimitTier": "tier4"} + mock_get.return_value = network_resp + + resp = client.mgmt.license.get() + self.assertEqual(resp, {"rateLimitTier": "tier4"}) + + mock_get.assert_called_with( + f"{client._mgmt_http_client.base_url}{MgmtV1.license_get_path}", + headers=mock.ANY, + params=None, + follow_redirects=True, + verify=SSLMatcher(), + timeout=DEFAULT_TIMEOUT_SECONDS, + ) + + def test_header_injected_after_handshake(self): + client = DescopeClient( + self.dummy_project_id, + self.public_key_dict, + False, + self.dummy_management_key, + ) + # Simulate a completed handshake by setting the cached tier directly. + client._mgmt_http_client.rate_limit_tier = "tier2" + headers = client._mgmt_http_client._get_default_headers() + self.assertEqual(headers.get("x-descope-license"), "tier2") + + def test_header_absent_when_tier_not_cached(self): + client = DescopeClient( + self.dummy_project_id, + self.public_key_dict, + False, + self.dummy_management_key, + ) + # Default state has no rate limit tier yet. + client._mgmt_http_client.rate_limit_tier = None + headers = client._mgmt_http_client._get_default_headers() + self.assertNotIn("x-descope-license", headers)