From 08e1c7ec6cf7642fb93e297e0f7c59a00a024c6b Mon Sep 17 00:00:00 2001 From: b-long Date: Wed, 4 Feb 2026 21:37:45 -0500 Subject: [PATCH] fix: implement KAS allowlist functionality --- src/otdf_python/cli.py | 71 +++++-- src/otdf_python/kas_allowlist.py | 182 +++++++++++++++++ src/otdf_python/kas_client.py | 26 ++- src/otdf_python/sdk.py | 30 +++ src/otdf_python/sdk_builder.py | 88 +++++++- tests/test_kas_allowlist.py | 333 +++++++++++++++++++++++++++++++ 6 files changed, 706 insertions(+), 24 deletions(-) create mode 100644 src/otdf_python/kas_allowlist.py create mode 100644 tests/test_kas_allowlist.py diff --git a/src/otdf_python/cli.py b/src/otdf_python/cli.py index 8037a35..6dc7f95 100644 --- a/src/otdf_python/cli.py +++ b/src/otdf_python/cli.py @@ -112,33 +112,14 @@ def load_client_credentials(creds_file_path: str) -> tuple[str, str]: ) from e -def build_sdk(args) -> SDK: - """Build SDK instance from CLI arguments.""" - builder = SDKBuilder() - - if args.platform_url: - builder.set_platform_endpoint(args.platform_url) - - # Auto-detect HTTP URLs and enable plaintext mode - if args.platform_url.startswith("http://") and ( - not hasattr(args, "plaintext") or not args.plaintext - ): - logger.debug( - f"Auto-detected HTTP URL {args.platform_url}, enabling plaintext mode" - ) - builder.use_insecure_plaintext_connection(True) - - if args.oidc_endpoint: - builder.set_issuer_endpoint(args.oidc_endpoint) - +def _configure_auth(builder: SDKBuilder, args) -> None: + """Configure authentication on the SDK builder.""" if args.client_id and args.client_secret: builder.client_secret(args.client_id, args.client_secret) elif hasattr(args, "with_client_creds_file") and args.with_client_creds_file: - # Load credentials from file client_id, client_secret = load_client_credentials(args.with_client_creds_file) builder.client_secret(client_id, client_secret) elif hasattr(args, "auth") and args.auth: - # Parse combined auth string (clientId:clientSecret) - legacy support auth_parts = args.auth.split(":") if len(auth_parts) != 2: raise CLIError( @@ -152,12 +133,49 @@ def build_sdk(args) -> SDK: "Authentication required: provide --with-client-creds-file OR --client-id and --client-secret", ) + +def _configure_kas_allowlist(builder: SDKBuilder, args) -> None: + """Configure KAS allowlist on the SDK builder.""" + if hasattr(args, "ignore_kas_allowlist") and args.ignore_kas_allowlist: + logger.warning( + "KAS allowlist validation is disabled. This may leak credentials " + "to malicious servers if decrypting untrusted TDF files." + ) + builder.with_ignore_kas_allowlist(True) + elif hasattr(args, "kas_allowlist") and args.kas_allowlist: + kas_urls = [url.strip() for url in args.kas_allowlist.split(",") if url.strip()] + logger.debug(f"Using KAS allowlist: {kas_urls}") + builder.with_kas_allowlist(kas_urls) + + +def build_sdk(args) -> SDK: + """Build SDK instance from CLI arguments.""" + builder = SDKBuilder() + + if args.platform_url: + builder.set_platform_endpoint(args.platform_url) + # Auto-detect HTTP URLs and enable plaintext mode + if args.platform_url.startswith("http://") and ( + not hasattr(args, "plaintext") or not args.plaintext + ): + logger.debug( + f"Auto-detected HTTP URL {args.platform_url}, enabling plaintext mode" + ) + builder.use_insecure_plaintext_connection(True) + + if args.oidc_endpoint: + builder.set_issuer_endpoint(args.oidc_endpoint) + + _configure_auth(builder, args) + if hasattr(args, "plaintext") and args.plaintext: builder.use_insecure_plaintext_connection(True) if args.insecure: builder.use_insecure_skip_verify(True) + _configure_kas_allowlist(builder, args) + return builder.build() @@ -476,6 +494,17 @@ def create_parser() -> argparse.ArgumentParser: security_group.add_argument( "--insecure", action="store_true", help="Skip TLS verification" ) + security_group.add_argument( + "--kas-allowlist", + help="Comma-separated list of trusted KAS URLs. " + "By default, only the platform URL's KAS endpoint is trusted.", + ) + security_group.add_argument( + "--ignore-kas-allowlist", + action="store_true", + help="WARNING: Disable KAS allowlist validation. This is insecure and " + "should only be used for testing. May leak credentials to malicious servers.", + ) # Subcommands subparsers = parser.add_subparsers(dest="command", help="Available commands") diff --git a/src/otdf_python/kas_allowlist.py b/src/otdf_python/kas_allowlist.py new file mode 100644 index 0000000..2ac1f4c --- /dev/null +++ b/src/otdf_python/kas_allowlist.py @@ -0,0 +1,182 @@ +"""KAS Allowlist: Validates KAS URLs against a list of trusted hosts. + +This module provides protection against SSRF attacks where malicious TDF files +could contain attacker-controlled KAS URLs to steal OIDC credentials. +""" + +import logging +from urllib.parse import urlparse + + +class KASAllowlist: + """Validates KAS URLs against an allowlist of trusted hosts. + + This class prevents credential theft by ensuring the SDK only sends + authentication tokens to trusted KAS endpoints. + + Example: + allowlist = KASAllowlist(["https://kas.example.com"]) + allowlist.is_allowed("https://kas.example.com/kas") # True + allowlist.is_allowed("https://evil.com/kas") # False + + """ + + def __init__(self, allowed_urls: list[str] | None = None, allow_all: bool = False): + """Initialize the KAS allowlist. + + Args: + allowed_urls: List of trusted KAS URLs. Each URL is normalized to + its origin (scheme://host:port) for comparison. + allow_all: If True, all URLs are allowed. Use only for testing. + A warning is logged when this is enabled. + + """ + self._allowed_origins: set[str] = set() + self._allow_all = allow_all + + if allow_all: + logging.warning( + "KAS allowlist is disabled (allow_all=True). " + "This is insecure and should only be used for testing." + ) + + if allowed_urls: + for url in allowed_urls: + self.add(url) + + def add(self, url: str) -> None: + """Add a URL to the allowlist. + + The URL is normalized to its origin (scheme://host:port) before storage. + Paths and query strings are stripped. + + Args: + url: The KAS URL to allow. Can include path components which + will be stripped for origin comparison. + + """ + origin = self._get_origin(url) + self._allowed_origins.add(origin) + logging.debug(f"Added KAS origin to allowlist: {origin}") + + def is_allowed(self, url: str) -> bool: + """Check if a URL is allowed by the allowlist. + + Args: + url: The KAS URL to check. + + Returns: + True if the URL's origin is in the allowlist or allow_all is True. + False otherwise. + + """ + if self._allow_all: + logging.debug(f"KAS URL allowed (allow_all=True): {url}") + return True + + if not self._allowed_origins: + logging.debug(f"KAS URL rejected (empty allowlist): {url}") + return False + + origin = self._get_origin(url) + allowed = origin in self._allowed_origins + if allowed: + logging.debug(f"KAS URL allowed: {url} (origin: {origin})") + else: + logging.debug( + f"KAS URL rejected: {url} (origin: {origin}, " + f"allowed: {self._allowed_origins})" + ) + return allowed + + def validate(self, url: str) -> None: + """Validate a URL against the allowlist, raising an exception if not allowed. + + Args: + url: The KAS URL to validate. + + Raises: + SDK.KasAllowlistException: If the URL is not in the allowlist. + + """ + if not self.is_allowed(url): + # Import here to avoid circular imports + from .sdk import SDK + + raise SDK.KasAllowlistException(url, self._allowed_origins) + + @property + def allowed_origins(self) -> set[str]: + """Return the set of allowed origins (read-only copy).""" + return self._allowed_origins.copy() + + @property + def allow_all(self) -> bool: + """Return whether all URLs are allowed.""" + return self._allow_all + + @staticmethod + def _get_origin(url: str) -> str: + """Extract the origin (scheme://host:port) from a URL. + + This normalizes URLs for comparison by stripping paths and query strings. + Default ports (80 for http, 443 for https) are included explicitly. + + Args: + url: The URL to extract the origin from. + + Returns: + Normalized origin string in format scheme://host:port + + """ + # Add scheme if missing + if "://" not in url: + url = "https://" + url + + try: + parsed = urlparse(url) + except Exception as e: + logging.warning(f"Failed to parse URL {url}: {e}") + # Return the URL as-is if parsing fails + return url.lower() + + scheme = (parsed.scheme or "https").lower() + hostname = (parsed.hostname or "").lower() + + if not hostname: + # URL might be malformed, return as-is + logging.warning(f"Could not extract hostname from URL: {url}") + return url.lower() + + # Determine port (use explicit port or default for scheme) + if parsed.port: + port = parsed.port + elif scheme == "http": + port = 80 + else: + port = 443 + + return f"{scheme}://{hostname}:{port}" + + @classmethod + def from_platform_url(cls, platform_url: str) -> "KASAllowlist": + """Create an allowlist from a platform URL. + + This is the default behavior: auto-allow the platform's KAS endpoint. + + Args: + platform_url: The OpenTDF platform URL. The KAS endpoint is + assumed to be at {platform_url}/kas. + + Returns: + KASAllowlist configured to allow the platform's KAS endpoint. + + """ + allowlist = cls() + # Add the platform URL itself (KAS might be at root or /kas) + allowlist.add(platform_url) + # Also construct the /kas endpoint explicitly + kas_url = platform_url.rstrip("/") + "/kas" + allowlist.add(kas_url) + logging.info(f"Created KAS allowlist from platform URL: {platform_url}") + return allowlist diff --git a/src/otdf_python/kas_client.py b/src/otdf_python/kas_client.py index b3ffe0d..82e90d3 100644 --- a/src/otdf_python/kas_client.py +++ b/src/otdf_python/kas_client.py @@ -38,13 +38,26 @@ def __init__( cache=None, use_plaintext=False, verify_ssl=True, + kas_allowlist=None, ): - """Initialize KAS client.""" + """Initialize KAS client. + + Args: + kas_url: Default KAS URL + token_source: Function that returns an authentication token + cache: Optional KASKeyCache for caching public keys + use_plaintext: Whether to use HTTP instead of HTTPS + verify_ssl: Whether to verify SSL certificates + kas_allowlist: Optional KASAllowlist for URL validation. If provided, + only URLs in the allowlist will be contacted. + + """ self.kas_url = kas_url self.token_source = token_source self.cache = cache or KASKeyCache() self.use_plaintext = use_plaintext self.verify_ssl = verify_ssl + self.kas_allowlist = kas_allowlist self.decryptor = None self.client_public_key = None @@ -86,15 +99,26 @@ def close(self): def _normalize_kas_url(self, url: str) -> str: """Normalize KAS URLs based on client security settings. + This method also validates the URL against the KAS allowlist if one + is configured. This prevents SSRF attacks where malicious TDF files + could contain attacker-controlled KAS URLs to steal OIDC credentials. + Args: url: The KAS URL to normalize Returns: Normalized URL with appropriate protocol and port + Raises: + KASAllowlistException: If the URL is not in the allowlist + """ from urllib.parse import urlparse + # Validate against allowlist BEFORE making any requests + if self.kas_allowlist is not None: + self.kas_allowlist.validate(url) + try: # Parse the URL parsed = urlparse(url) diff --git a/src/otdf_python/sdk.py b/src/otdf_python/sdk.py index 98ed3ff..d3476c6 100644 --- a/src/otdf_python/sdk.py +++ b/src/otdf_python/sdk.py @@ -37,6 +37,7 @@ def __init__( token_source=None, sdk_ssl_verify=True, use_plaintext=False, + kas_allowlist=None, ): """Initialize the KAS client. @@ -45,6 +46,7 @@ def __init__( token_source: Function that returns an authentication token sdk_ssl_verify: Whether to verify SSL certificates use_plaintext: Whether to use plaintext HTTP connections instead of HTTPS + kas_allowlist: Optional KASAllowlist for URL validation """ from .kas_client import KASClient @@ -54,6 +56,7 @@ def __init__( token_source=token_source, verify_ssl=sdk_ssl_verify, use_plaintext=use_plaintext, + kas_allowlist=kas_allowlist, ) # Store the parameters for potential use self._sdk_ssl_verify = sdk_ssl_verify @@ -405,6 +408,33 @@ class KasBadRequestException(SDKException): class KasAllowlistException(SDKException): """Throw when KAS allowlist check fails.""" + def __init__( + self, + url: str, + allowed_origins: set[str] | None = None, + message: str | None = None, + ): + """Initialize exception. + + Args: + url: The KAS URL that was rejected + allowed_origins: Set of allowed origin URLs + message: Optional custom message (auto-generated if not provided) + + """ + self.url = url + self.allowed_origins = allowed_origins or set() + if message is None: + origins_str = ( + ", ".join(sorted(self.allowed_origins)) + if self.allowed_origins + else "none" + ) + message = ( + f"KAS URL not in allowlist: {url}. Allowed origins: {origins_str}" + ) + super().__init__(message) + class AssertionException(SDKException): """Throw when an assertion validation fails.""" diff --git a/src/otdf_python/sdk_builder.py b/src/otdf_python/sdk_builder.py index 2234ea2..6cf74fb 100644 --- a/src/otdf_python/sdk_builder.py +++ b/src/otdf_python/sdk_builder.py @@ -10,6 +10,7 @@ import httpx +from otdf_python.kas_allowlist import KASAllowlist from otdf_python.sdk import KAS, SDK from otdf_python.sdk_exceptions import AutoConfigureException @@ -47,6 +48,8 @@ def __init__(self): self.ssl_context: ssl.SSLContext | None = None self.auth_token: str | None = None self.cert_paths: list[str] = [] + self._kas_allowlist_urls: list[str] | None = None + self._ignore_kas_allowlist: bool = False @staticmethod def new_builder() -> "SDKBuilder": @@ -201,6 +204,54 @@ def bearer_token(self, token: str) -> "SDKBuilder": self.auth_token = token return self + def with_kas_allowlist(self, urls: list[str]) -> "SDKBuilder": + """Set the KAS allowlist to restrict which KAS servers the SDK will contact. + + This provides protection against SSRF attacks where malicious TDF files + could contain attacker-controlled KAS URLs to steal OIDC credentials. + + By default (if no allowlist is set), only the platform's KAS endpoint + is allowed. + + Args: + urls: List of trusted KAS URLs. Each URL is normalized to its + origin (scheme://host:port) for comparison. + + Returns: + self: The builder instance for chaining + + Example: + builder.with_kas_allowlist([ + "https://kas.example.com", + "https://kas2.example.com:8443" + ]) + + """ + self._kas_allowlist_urls = urls + return self + + def with_ignore_kas_allowlist(self, ignore: bool = True) -> "SDKBuilder": + """Configure whether to skip KAS allowlist validation. + + WARNING: This is insecure and should only be used for testing or + development. When enabled, the SDK will contact any KAS URL found + in TDF files, potentially leaking credentials to malicious servers. + + Args: + ignore: Whether to ignore the KAS allowlist (default: True) + + Returns: + self: The builder instance for chaining + + """ + self._ignore_kas_allowlist = ignore + if ignore: + logger.warning( + "KAS allowlist validation is disabled. This is insecure and " + "should only be used for testing." + ) + return self + def _discover_token_endpoint_from_platform(self) -> None: """Discover token endpoint using OpenTDF platform configuration. @@ -356,6 +407,34 @@ def _get_token_from_client_credentials(self) -> str: f"Error during token acquisition: {e!s}" ) from e + def _create_kas_allowlist(self) -> KASAllowlist | None: + """Create the KAS allowlist based on builder configuration. + + Returns: + KASAllowlist configured based on builder settings, or None if + allowlist validation is disabled. + + """ + # If ignoring allowlist, return an allow-all instance + if self._ignore_kas_allowlist: + return KASAllowlist(allow_all=True) + + # If explicit allowlist provided, use it + if self._kas_allowlist_urls: + allowlist = KASAllowlist(self._kas_allowlist_urls) + # Also add the platform URL for convenience + if self.platform_endpoint: + allowlist.add(self.platform_endpoint) + allowlist.add(self.platform_endpoint.rstrip("/") + "/kas") + return allowlist + + # Default: create allowlist from platform URL only + if self.platform_endpoint: + return KASAllowlist.from_platform_url(self.platform_endpoint) + + # No platform endpoint set yet - return None and let SDK handle it + return None + def _create_services(self) -> SDK.Services: """Create service client instances. @@ -371,11 +450,15 @@ def _create_services(self) -> SDK.Services: ssl_verify = not self.insecure_skip_verify + # Create the KAS allowlist + kas_allowlist = self._create_kas_allowlist() + class ServicesImpl(SDK.Services): - def __init__(self, builder_instance): + def __init__(self, builder_instance, allowlist: KASAllowlist | None): self.closed = False self._ssl_verify = ssl_verify self._builder = builder_instance + self._kas_allowlist = allowlist def kas(self) -> KAS: """Return the KAS interface with SSL verification settings.""" @@ -394,6 +477,7 @@ def token_source(): token_source=token_source, sdk_ssl_verify=self._ssl_verify, use_plaintext=self._builder.use_plaintext, + kas_allowlist=self._kas_allowlist, ) return kas_impl @@ -403,7 +487,7 @@ def close(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close() - return ServicesImpl(self) + return ServicesImpl(self, kas_allowlist) def build(self) -> SDK: """Build and return an SDK instance with configured properties. diff --git a/tests/test_kas_allowlist.py b/tests/test_kas_allowlist.py new file mode 100644 index 0000000..c2619ff --- /dev/null +++ b/tests/test_kas_allowlist.py @@ -0,0 +1,333 @@ +"""Tests for the KAS allowlist functionality. + +This module tests the KASAllowlist class which provides protection against +SSRF attacks where malicious TDF files could contain attacker-controlled +KAS URLs to steal OIDC credentials. +""" + +import logging + +import pytest + +from otdf_python.kas_allowlist import KASAllowlist +from otdf_python.kas_client import KASClient +from otdf_python.sdk import SDK +from otdf_python.sdk_builder import SDKBuilder + + +class TestKASAllowlist: + """Test cases for the KASAllowlist class.""" + + def test_empty_allowlist_rejects_all(self): + """Empty allowlist should reject all URLs.""" + allowlist = KASAllowlist() + assert not allowlist.is_allowed("https://example.com") + assert not allowlist.is_allowed("https://evil.com") + + def test_add_and_check_urls(self): + """Test adding URLs and checking if they are allowed.""" + allowlist = KASAllowlist() + allowlist.add("https://kas.example.com") + + assert allowlist.is_allowed("https://kas.example.com") + assert allowlist.is_allowed("https://kas.example.com/kas") + assert allowlist.is_allowed("https://kas.example.com:443/some/path") + assert not allowlist.is_allowed("https://evil.com") + assert not allowlist.is_allowed("https://kas.evil.com") + + def test_constructor_with_urls(self): + """Test initializing with a list of URLs.""" + allowlist = KASAllowlist( + ["https://kas1.example.com", "https://kas2.example.com"] + ) + + assert allowlist.is_allowed("https://kas1.example.com") + assert allowlist.is_allowed("https://kas2.example.com") + assert not allowlist.is_allowed("https://kas3.example.com") + + def test_allow_all_mode(self): + """Test allow_all mode allows everything.""" + allowlist = KASAllowlist(allow_all=True) + + assert allowlist.is_allowed("https://any.domain.com") + assert allowlist.is_allowed("https://evil.com") + assert allowlist.is_allowed("http://localhost:8080") + + def test_origin_normalization_strips_path(self): + """Test that paths are stripped when normalizing origins.""" + allowlist = KASAllowlist() + allowlist.add("https://kas.example.com/some/path") + + # Should allow any path on the same origin + assert allowlist.is_allowed("https://kas.example.com") + assert allowlist.is_allowed("https://kas.example.com/different/path") + assert allowlist.is_allowed("https://kas.example.com/kas") + + def test_origin_normalization_default_ports(self): + """Test that default ports are normalized correctly.""" + allowlist = KASAllowlist() + allowlist.add("https://kas.example.com") # Implies port 443 + + # Should match with explicit port 443 + assert allowlist.is_allowed("https://kas.example.com:443") + assert allowlist.is_allowed("https://kas.example.com:443/kas") + + # Should NOT match different port + assert not allowlist.is_allowed("https://kas.example.com:8443") + + def test_origin_normalization_http_default_port(self): + """Test that HTTP default port (80) is normalized correctly.""" + allowlist = KASAllowlist() + allowlist.add("http://kas.example.com") + + assert allowlist.is_allowed("http://kas.example.com:80") + assert not allowlist.is_allowed("http://kas.example.com:8080") + + def test_origin_normalization_explicit_port(self): + """Test that explicit ports are preserved.""" + allowlist = KASAllowlist() + allowlist.add("https://kas.example.com:8443") + + assert allowlist.is_allowed("https://kas.example.com:8443") + assert allowlist.is_allowed("https://kas.example.com:8443/kas") + assert not allowlist.is_allowed("https://kas.example.com:443") + assert not allowlist.is_allowed("https://kas.example.com") + + def test_origin_normalization_adds_scheme(self): + """Test that missing schemes default to https.""" + allowlist = KASAllowlist() + allowlist.add("kas.example.com") + + assert allowlist.is_allowed("https://kas.example.com") + assert not allowlist.is_allowed("http://kas.example.com") + + def test_case_insensitive_hostname(self): + """Test that hostname comparison is case-insensitive.""" + allowlist = KASAllowlist() + allowlist.add("https://KAS.Example.COM") + + assert allowlist.is_allowed("https://kas.example.com") + assert allowlist.is_allowed("https://KAS.EXAMPLE.COM") + + def test_validate_raises_exception(self): + """Test that validate() raises exception for disallowed URLs.""" + allowlist = KASAllowlist(["https://trusted.com"]) + + # Should not raise for allowed URL + allowlist.validate("https://trusted.com/kas") + + # Should raise for disallowed URL + with pytest.raises(SDK.KasAllowlistException) as excinfo: + allowlist.validate("https://evil.com/kas") + + assert "evil.com" in str(excinfo.value) + assert "trusted.com" in str(excinfo.value) + + def test_exception_contains_url_and_origins(self): + """Test that exception contains useful debugging information.""" + allowlist = KASAllowlist(["https://kas1.com", "https://kas2.com"]) + + with pytest.raises(SDK.KasAllowlistException) as excinfo: + allowlist.validate("https://attacker.com") + + exception = excinfo.value + assert exception.url == "https://attacker.com" + assert "https://kas1.com:443" in exception.allowed_origins + assert "https://kas2.com:443" in exception.allowed_origins + + def test_allowed_origins_property(self): + """Test the allowed_origins property returns a copy.""" + allowlist = KASAllowlist(["https://example.com"]) + + origins = allowlist.allowed_origins + origins.add("https://modified.com") + + # Original should not be modified + assert "https://modified.com:443" not in allowlist.allowed_origins + + def test_from_platform_url(self): + """Test creating allowlist from platform URL.""" + allowlist = KASAllowlist.from_platform_url("https://platform.example.com") + + # Should allow the platform URL + assert allowlist.is_allowed("https://platform.example.com") + # Should allow the /kas endpoint + assert allowlist.is_allowed("https://platform.example.com/kas") + + def test_from_platform_url_with_trailing_slash(self): + """Test creating allowlist from platform URL with trailing slash.""" + allowlist = KASAllowlist.from_platform_url("https://platform.example.com/") + + assert allowlist.is_allowed("https://platform.example.com") + assert allowlist.is_allowed("https://platform.example.com/kas") + + def test_allow_all_logs_warning(self, caplog): + """Test that allow_all mode logs a warning.""" + with caplog.at_level(logging.WARNING): + KASAllowlist(allow_all=True) + + assert "insecure" in caplog.text.lower() + + +class TestKASClientAllowlistIntegration: + """Test KASClient integration with allowlist.""" + + def test_client_without_allowlist_allows_all(self): + """Client without allowlist should allow all URLs (backward compatibility).""" + client = KASClient(kas_allowlist=None) + + # Should not raise - no validation without allowlist + normalized = client._normalize_kas_url("https://any.domain.com/kas") + assert "any.domain.com" in normalized + + def test_client_with_allowlist_validates(self): + """Client with allowlist should validate URLs.""" + allowlist = KASAllowlist(["https://trusted.kas.com"]) + client = KASClient(kas_allowlist=allowlist) + + # Should work for trusted URL + normalized = client._normalize_kas_url("https://trusted.kas.com/kas") + assert "trusted.kas.com" in normalized + + # Should raise for untrusted URL + with pytest.raises(SDK.KasAllowlistException): + client._normalize_kas_url("https://evil.com/kas") + + def test_client_allowlist_checked_before_normalization(self): + """Allowlist should be checked before any network operations.""" + allowlist = KASAllowlist(["https://trusted.com"]) + client = KASClient(kas_allowlist=allowlist) + + # Even malformed URLs should be rejected if not in allowlist + with pytest.raises(SDK.KasAllowlistException): + client._normalize_kas_url("https://evil.com") + + +class TestSDKBuilderAllowlist: + """Test SDKBuilder allowlist configuration.""" + + def test_with_kas_allowlist(self): + """Test configuring explicit allowlist.""" + builder = SDKBuilder() + result = builder.with_kas_allowlist(["https://kas1.com", "https://kas2.com"]) + + assert result is builder # Returns self for chaining + assert builder._kas_allowlist_urls == ["https://kas1.com", "https://kas2.com"] + + def test_with_ignore_kas_allowlist(self): + """Test ignoring allowlist.""" + builder = SDKBuilder() + result = builder.with_ignore_kas_allowlist(True) + + assert result is builder + assert builder._ignore_kas_allowlist is True + + def test_with_ignore_kas_allowlist_false(self): + """Test explicitly not ignoring allowlist.""" + builder = SDKBuilder() + builder.with_ignore_kas_allowlist(True) + builder.with_ignore_kas_allowlist(False) + + assert builder._ignore_kas_allowlist is False + + def test_create_kas_allowlist_from_platform_url(self): + """Test that allowlist is created from platform URL by default.""" + builder = SDKBuilder() + builder.set_platform_endpoint("https://platform.example.com") + + allowlist = builder._create_kas_allowlist() + + assert allowlist is not None + assert allowlist.is_allowed("https://platform.example.com") + assert allowlist.is_allowed("https://platform.example.com/kas") + assert not allowlist.is_allowed("https://other.com") + + def test_create_kas_allowlist_explicit(self): + """Test that explicit allowlist is used when provided.""" + builder = SDKBuilder() + builder.set_platform_endpoint("https://platform.example.com") + builder.with_kas_allowlist(["https://external-kas.com"]) + + allowlist = builder._create_kas_allowlist() + + assert allowlist is not None + # Should include explicit URL + assert allowlist.is_allowed("https://external-kas.com") + # Should also include platform URL + assert allowlist.is_allowed("https://platform.example.com") + + def test_create_kas_allowlist_ignore(self): + """Test that allow_all is returned when ignoring.""" + builder = SDKBuilder() + builder.set_platform_endpoint("https://platform.example.com") + builder.with_ignore_kas_allowlist(True) + + allowlist = builder._create_kas_allowlist() + + assert allowlist is not None + assert allowlist.allow_all is True + assert allowlist.is_allowed("https://any.domain.com") + + def test_allowlist_with_ignore_logs_warning(self, caplog): + """Test that ignoring allowlist logs a warning.""" + builder = SDKBuilder() + + with caplog.at_level(logging.WARNING): + builder.with_ignore_kas_allowlist(True) + + assert "insecure" in caplog.text.lower() + + +class TestSSRFProtection: + """Test SSRF protection scenarios.""" + + def test_malicious_tdf_url_rejected(self): + """Simulate a malicious TDF with attacker-controlled KAS URL.""" + # Simulate SDK configured with platform URL + allowlist = KASAllowlist.from_platform_url("https://legitimate-platform.com") + client = KASClient(kas_allowlist=allowlist) + + # Attacker crafts TDF with their KAS URL + malicious_kas_url = "https://attacker.evil.com/steal-tokens" + + # SDK should reject this URL before sending any credentials + with pytest.raises(SDK.KasAllowlistException) as excinfo: + client._normalize_kas_url(malicious_kas_url) + + assert "attacker.evil.com" in str(excinfo.value) + + def test_legitimate_kas_url_accepted(self): + """Verify legitimate KAS URLs are accepted.""" + allowlist = KASAllowlist.from_platform_url("https://platform.company.com") + client = KASClient(kas_allowlist=allowlist) + + # TDF with legitimate KAS URL should work + legitimate_url = "https://platform.company.com/kas" + normalized = client._normalize_kas_url(legitimate_url) + + assert "platform.company.com" in normalized + + def test_multi_kas_deployment(self): + """Test deployment with multiple KAS servers.""" + # Organization has multiple KAS servers + allowlist = KASAllowlist( + [ + "https://kas-primary.company.com", + "https://kas-secondary.company.com", + "https://kas-dr.company.com", + ] + ) + client = KASClient(kas_allowlist=allowlist) + + # All configured servers should work + for url in [ + "https://kas-primary.company.com/kas", + "https://kas-secondary.company.com/kas", + "https://kas-dr.company.com/kas", + ]: + normalized = client._normalize_kas_url(url) + assert normalized is not None + + # Unknown server should be rejected + with pytest.raises(SDK.KasAllowlistException): + client._normalize_kas_url("https://unknown-kas.other.com/kas")