Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions descope/descope_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
import os
import warnings
from typing import Iterable
Expand All @@ -19,8 +20,13 @@
from descope.common import DEFAULT_TIMEOUT_SECONDS, AccessKeyLoginOptions, EndpointsV1
from descope.exceptions import ERROR_TYPE_INVALID_ARGUMENT, AuthException
from descope.http_client import HTTPClient
from descope.management.common import MgmtV1
from descope.mgmt import MGMT # noqa: F401

logger = logging.getLogger(__name__)

LICENSE_HANDSHAKE_TIMEOUT_SECONDS = 5.0


class DescopeClient:
ALGORITHM_KEY = "alg"
Expand Down Expand Up @@ -106,6 +112,38 @@
self._auth_http_client = auth_http_client
self._mgmt_http_client = mgmt_http_client

# Synchronous license handshake so the first management request after
# construction can carry the x-descope-license header. Backend skips
# license-header validation for the GetLicense endpoint itself, so the
# initial request is safe before the tier is cached.
if mgmt_http_client.management_key:
self._fetch_rate_limit_tier()

def _fetch_rate_limit_tier(self) -> None:
try:
response = httpx.get(
f"{self._mgmt_http_client.base_url}{MgmtV1.license_get_path}",
headers={
"Authorization": (
f"Bearer {self._mgmt_http_client.project_id}:{self._mgmt_http_client.management_key}"
)
},
follow_redirects=True,
verify=self._mgmt_http_client.client_verify,
timeout=LICENSE_HANDSHAKE_TIMEOUT_SECONDS,
)
if not response.is_success:
logger.warning(
"License handshake returned non-success status %s",
response.status_code,
)
return
tier = response.json().get("rateLimitTier")

Check warning on line 141 in descope/descope_client.py

View workflow job for this annotation

GitHub Actions / Coverage

This line has no coverage
if tier:

Check warning on line 142 in descope/descope_client.py

View workflow job for this annotation

GitHub Actions / Coverage

This line has no coverage
self._mgmt_http_client.rate_limit_tier = tier

Check warning on line 143 in descope/descope_client.py

View workflow job for this annotation

GitHub Actions / Coverage

This line has no coverage
except Exception as e:
logger.warning("License handshake failed: %s", e)

@property
def mgmt(self):
return self._mgmt
Expand Down
6 changes: 6 additions & 0 deletions descope/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ def __init__(
self.management_key = management_key
self.verbose = verbose
self._thread_local = threading.local()
# Populated by the license handshake when a management key is configured.
# Sent in the x-descope-license header so Cloudflare can apply the right
# rate limit bucket per customer tier.
self.rate_limit_tier: str | None = None

# Setup SSL verification for httpx (backwards compatibility with requests)
self.client_verify: bool | ssl.SSLContext = False
Expand Down Expand Up @@ -400,4 +404,6 @@ def _get_default_headers(self, pswd: str | None = None):
if self.management_key:
bearer = f"{bearer}:{self.management_key}"
headers["Authorization"] = f"Bearer {bearer}"
if self.rate_limit_tier:
headers["x-descope-license"] = self.rate_limit_tier
return headers
3 changes: 3 additions & 0 deletions descope/management/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,9 @@ class MgmtV1:
mgmt_key_delete_path = "/v1/mgmt/managementkey/delete"
mgmt_key_search_path = "/v1/mgmt/managementkey/search"

# license
license_get_path = "/v1/mgmt/license"


class MgmtSignUpOptions:
def __init__(
Expand Down
19 changes: 19 additions & 0 deletions descope/management/license.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from descope._http_base import HTTPBase
from descope.management.common import MgmtV1


class License(HTTPBase):
def get(self) -> dict:
"""
Fetch the rate limit tier for the project's company license.

Returns a dict with a ``rateLimitTier`` field whose value is one of
``tier1`` (free), ``tier2`` (pro), ``tier3`` (growth), or ``tier4``
(enterprise). The SDK sends this value in the ``x-descope-license``
header on every management request so Cloudflare can apply the right
rate limit bucket.
"""
response = self._http.get(MgmtV1.license_get_path)
return response.json()
7 changes: 7 additions & 0 deletions descope/mgmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from descope.management.flow import Flow
from descope.management.group import Group
from descope.management.jwt import JWT
from descope.management.license import License
from descope.management.management_key import ManagementKey
from descope.management.outbound_application import (
OutboundApplication,
Expand Down Expand Up @@ -45,6 +46,7 @@ def __init__(self, http_client: HTTPClient, auth: Auth, fga_cache_url: Optional[
self._flow = Flow(http_client)
self._group = Group(http_client)
self._jwt = JWT(http_client, auth=auth)
self._license = License(http_client)
self._management_key = ManagementKey(http_client)
self._outbound_application = OutboundApplication(http_client)
self._outbound_application_by_token = OutboundApplicationByToken(http_client)
Expand Down Expand Up @@ -94,6 +96,11 @@ def jwt(self):
self._ensure_management_key("jwt")
return self._jwt

@property
def license(self):
self._ensure_management_key("license")
return self._license

@property
def permission(self):
self._ensure_management_key("permission")
Expand Down
85 changes: 85 additions & 0 deletions tests/management/test_license.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from unittest import mock
from unittest.mock import patch

from descope import AuthException, DescopeClient
from descope.common import DEFAULT_TIMEOUT_SECONDS
from descope.management.common import MgmtV1

from .. import common
from ..testutils import SSLMatcher


class TestLicense(common.DescopeTest):
def setUp(self) -> None:
super().setUp()
self.dummy_project_id = "dummy"
self.dummy_management_key = "key"
self.public_key_dict = {
"alg": "ES384",
"crv": "P-384",
"kid": "P2CtzUhdqpIF2ys9gg7ms06UvtC4",
"kty": "EC",
"use": "sig",
"x": "pX1l7nT2turcK5_Cdzos8SKIhpLh1Wy9jmKAVyMFiOCURoj-WQX1J0OUQqMsQO0s",
"y": "B0_nWAv2pmG_PzoH3-bSYZZzLNKUA0RoE2SH7DaS0KV4rtfWZhYd0MEr0xfdGKx0",
}

def test_get_failure(self):
client = DescopeClient(
self.dummy_project_id,
self.public_key_dict,
False,
self.dummy_management_key,
)
with patch("httpx.get") as mock_get:
mock_get.return_value.is_success = False
self.assertRaises(AuthException, client.mgmt.license.get)

def test_get_success(self):
client = DescopeClient(
self.dummy_project_id,
self.public_key_dict,
False,
self.dummy_management_key,
)
with patch("httpx.get") as mock_get:
network_resp = mock.Mock()
network_resp.is_success = True
network_resp.json.return_value = {"rateLimitTier": "tier4"}
mock_get.return_value = network_resp

resp = client.mgmt.license.get()
self.assertEqual(resp, {"rateLimitTier": "tier4"})

mock_get.assert_called_with(
f"{client._mgmt_http_client.base_url}{MgmtV1.license_get_path}",
headers=mock.ANY,
params=None,
follow_redirects=True,
verify=SSLMatcher(),
timeout=DEFAULT_TIMEOUT_SECONDS,
)

def test_header_injected_after_handshake(self):
client = DescopeClient(
self.dummy_project_id,
self.public_key_dict,
False,
self.dummy_management_key,
)
# Simulate a completed handshake by setting the cached tier directly.
client._mgmt_http_client.rate_limit_tier = "tier2"
headers = client._mgmt_http_client._get_default_headers()
self.assertEqual(headers.get("x-descope-license"), "tier2")

def test_header_absent_when_tier_not_cached(self):
client = DescopeClient(
self.dummy_project_id,
self.public_key_dict,
False,
self.dummy_management_key,
)
# Default state has no rate limit tier yet.
client._mgmt_http_client.rate_limit_tier = None
headers = client._mgmt_http_client._get_default_headers()
self.assertNotIn("x-descope-license", headers)
Loading