diff --git a/src/workos/_base_client.py b/src/workos/_base_client.py index 83c18d39..1becc09d 100644 --- a/src/workos/_base_client.py +++ b/src/workos/_base_client.py @@ -10,6 +10,7 @@ from datetime import datetime, timezone from email.utils import parsedate_to_datetime from typing import Any, Dict, Optional, Type, cast, overload +from urllib.parse import quote import httpx @@ -53,8 +54,15 @@ def __init__( request_timeout: Optional[int] = None, jwt_leeway: float = 0.0, max_retries: int = MAX_RETRIES, + is_public: bool = False, ) -> None: - self._api_key = api_key or os.environ.get("WORKOS_API_KEY") + self._is_public = is_public + # Public clients (PKCE / browser / mobile / CLI) must never attach + # an API key, even if WORKOS_API_KEY is present in the environment. + if is_public: + self._api_key: Optional[str] = None + else: + self._api_key = api_key or os.environ.get("WORKOS_API_KEY") self.client_id = client_id or os.environ.get("WORKOS_CLIENT_ID") if not self._api_key and not self.client_id: raise ValueError( @@ -128,6 +136,21 @@ def _resolve_base_url(self, request_options: Optional[RequestOptions]) -> str: return str(base_url).rstrip("/") return self._base_url.rstrip("/") + @staticmethod + def _encode_path(path: str) -> str: + """Percent-encode each path segment to prevent path-traversal/injection. + + Splits on ``/`` and applies ``urllib.parse.quote(seg, safe='')`` to each + segment so that user-supplied IDs containing reserved characters (``/``, + ``?``, ``#``, ``%``, etc.) cannot escape their intended segment. The + leading slash (if any) is preserved. + """ + if not path: + return path + leading = "/" if path.startswith("/") else "" + body = path[1:] if leading else path + return leading + "/".join(quote(seg, safe="") for seg in body.split("/")) + def _resolve_timeout(self, request_options: Optional[RequestOptions]) -> float: timeout = self._request_timeout if request_options: @@ -332,6 +355,7 @@ def __init__( request_timeout: Optional[int] = None, jwt_leeway: float = 0.0, max_retries: int = MAX_RETRIES, + is_public: bool = False, ) -> None: """Initialize the WorkOS client. @@ -342,6 +366,10 @@ def __init__( request_timeout: HTTP request timeout in seconds. Falls back to WORKOS_REQUEST_TIMEOUT or 60. jwt_leeway: JWT clock skew leeway in seconds. max_retries: Maximum number of retries for failed requests. Defaults to 3. + is_public: When True, mark this client as public (PKCE / browser + / mobile / CLI). The API key is forced to None and the + ``WORKOS_API_KEY`` environment variable is ignored. Use + ``create_public_client`` instead of setting this directly. Raises: ValueError: If neither api_key nor client_id is provided, directly or via environment variables. @@ -353,6 +381,7 @@ def __init__( request_timeout=request_timeout, jwt_leeway=jwt_leeway, max_retries=max_retries, + is_public=is_public, ) self._client = httpx.Client( timeout=self._request_timeout, follow_redirects=True @@ -406,7 +435,7 @@ def request( request_options: Optional[RequestOptions] = None, ) -> Any: """Make an HTTP request with retry logic.""" - url = f"{self._resolve_base_url(request_options)}/{path}" + url = f"{self._resolve_base_url(request_options)}/{self._encode_path(path).lstrip('/')}" headers = self._build_headers(method, idempotency_key, request_options) timeout = self._resolve_timeout(request_options) max_retries = self._resolve_max_retries(request_options) @@ -557,6 +586,7 @@ def __init__( request_timeout: Optional[int] = None, jwt_leeway: float = 0.0, max_retries: int = MAX_RETRIES, + is_public: bool = False, ) -> None: """Initialize the async WorkOS client. @@ -578,6 +608,7 @@ def __init__( request_timeout=request_timeout, jwt_leeway=jwt_leeway, max_retries=max_retries, + is_public=is_public, ) self._client = httpx.AsyncClient( timeout=self._request_timeout, follow_redirects=True @@ -631,7 +662,7 @@ async def request( request_options: Optional[RequestOptions] = None, ) -> Any: """Make an async HTTP request with retry logic.""" - url = f"{self._resolve_base_url(request_options)}/{path}" + url = f"{self._resolve_base_url(request_options)}/{self._encode_path(path).lstrip('/')}" headers = self._build_headers(method, idempotency_key, request_options) timeout = self._resolve_timeout(request_options) max_retries = self._resolve_max_retries(request_options) diff --git a/src/workos/actions.py b/src/workos/actions.py index cd2b3b7f..258d7986 100644 --- a/src/workos/actions.py +++ b/src/workos/actions.py @@ -44,7 +44,7 @@ def _verify_signature( timestamp_in_seconds = int(issued_timestamp) / 1000 seconds_since_issued = current_time - timestamp_in_seconds - if seconds_since_issued > tolerance: + if abs(seconds_since_issued) > tolerance: raise ValueError("Timestamp outside the tolerance zone") body_str = payload.decode("utf-8") if isinstance(payload, bytes) else payload diff --git a/src/workos/public_client.py b/src/workos/public_client.py index 057b7e40..b4bca338 100644 --- a/src/workos/public_client.py +++ b/src/workos/public_client.py @@ -33,7 +33,9 @@ def create_public_client( from ._client import WorkOSClient return WorkOSClient( + api_key=None, client_id=client_id, base_url=base_url, request_timeout=request_timeout, + is_public=True, ) diff --git a/src/workos/session.py b/src/workos/session.py index fbee198e..f289c5eb 100644 --- a/src/workos/session.py +++ b/src/workos/session.py @@ -24,6 +24,20 @@ from cryptography.fernet import Fernet from jwt import PyJWKClient +from ._errors import ( + AuthenticationError, + AuthenticationMethodNotAllowedError, + EmailVerificationRequiredError, + MfaChallengeError, + MfaEnrollmentError, + OrganizationAuthMethodsRequiredError, + OrganizationSelectionRequiredError, + RadarChallengeError, + SsoRequiredError, + WorkOSConnectionError, + WorkOSTimeoutError, +) + if TYPE_CHECKING: from ._client import AsyncWorkOSClient, WorkOSClient @@ -37,6 +51,51 @@ class AuthenticateWithSessionCookieFailureReason(Enum): INVALID_JWT = "invalid_jwt" INVALID_SESSION_COOKIE = "invalid_session_cookie" NO_SESSION_COOKIE_PROVIDED = "no_session_cookie_provided" + MFA_CHALLENGE_REQUIRED = "mfa_challenge_required" + MFA_ENROLLMENT_REQUIRED = "mfa_enrollment_required" + SSO_REQUIRED = "sso_required" + EMAIL_VERIFICATION_REQUIRED = "email_verification_required" + ORGANIZATION_SELECTION_REQUIRED = "organization_selection_required" + ORGANIZATION_AUTH_METHODS_REQUIRED = "organization_auth_methods_required" + AUTHENTICATION_METHOD_NOT_ALLOWED = "authentication_method_not_allowed" + RADAR_CHALLENGE_REQUIRED = "radar_challenge_required" + REFRESH_DENIED = "refresh_denied" + REFRESH_NETWORK_ERROR = "refresh_network_error" + + +def _map_refresh_exception_to_reason( + exc: Exception, +) -> Union[AuthenticateWithSessionCookieFailureReason, str]: + """Map an exception raised by a refresh request to a structured reason. + + Falls back to ``str(exc)`` for unknown errors so callers retain the + pre-existing string form for diagnostics. + """ + if isinstance(exc, MfaChallengeError): + return AuthenticateWithSessionCookieFailureReason.MFA_CHALLENGE_REQUIRED + if isinstance(exc, MfaEnrollmentError): + return AuthenticateWithSessionCookieFailureReason.MFA_ENROLLMENT_REQUIRED + if isinstance(exc, SsoRequiredError): + return AuthenticateWithSessionCookieFailureReason.SSO_REQUIRED + if isinstance(exc, EmailVerificationRequiredError): + return AuthenticateWithSessionCookieFailureReason.EMAIL_VERIFICATION_REQUIRED + if isinstance(exc, OrganizationSelectionRequiredError): + return ( + AuthenticateWithSessionCookieFailureReason.ORGANIZATION_SELECTION_REQUIRED + ) + if isinstance(exc, OrganizationAuthMethodsRequiredError): + return AuthenticateWithSessionCookieFailureReason.ORGANIZATION_AUTH_METHODS_REQUIRED + if isinstance(exc, AuthenticationMethodNotAllowedError): + return ( + AuthenticateWithSessionCookieFailureReason.AUTHENTICATION_METHOD_NOT_ALLOWED + ) + if isinstance(exc, RadarChallengeError): + return AuthenticateWithSessionCookieFailureReason.RADAR_CHALLENGE_REQUIRED + if isinstance(exc, AuthenticationError): + return AuthenticateWithSessionCookieFailureReason.REFRESH_DENIED + if isinstance(exc, (WorkOSConnectionError, WorkOSTimeoutError)): + return AuthenticateWithSessionCookieFailureReason.REFRESH_NETWORK_ERROR + return str(exc) @dataclass(slots=True) @@ -328,7 +387,7 @@ def refresh( ) except Exception as e: return RefreshWithSessionCookieErrorResponse( - authenticated=False, reason=str(e) + authenticated=False, reason=_map_refresh_exception_to_reason(e) ) def get_logout_url(self, return_to: Optional[str] = None) -> str: @@ -507,7 +566,7 @@ async def refresh( ) except Exception as e: return RefreshWithSessionCookieErrorResponse( - authenticated=False, reason=str(e) + authenticated=False, reason=_map_refresh_exception_to_reason(e) ) async def get_logout_url(self, return_to: Optional[str] = None) -> str: diff --git a/src/workos/vault.py b/src/workos/vault.py index b5d50895..06740d2e 100644 --- a/src/workos/vault.py +++ b/src/workos/vault.py @@ -282,10 +282,12 @@ def _decode_u32_leb128(buf: bytes) -> Tuple[int, int]: res = 0 bit = 0 for i, b in enumerate(buf): - if i > 4: + if i >= 4 and (b & 0x80) != 0: raise ValueError("LEB128 integer overflow (was more than 4 bytes)") res |= (b & 0x7F) << (7 * bit) if (b & 0x80) == 0: + if res > 0xFFFFFFFF: + raise ValueError("LEB128 integer overflow (exceeds 32 bits)") return res, i + 1 bit += 1 raise ValueError("LEB128 integer not found") diff --git a/src/workos/webhooks/_resource.py b/src/workos/webhooks/_resource.py index 0b0d11fb..be93fa58 100644 --- a/src/workos/webhooks/_resource.py +++ b/src/workos/webhooks/_resource.py @@ -263,7 +263,7 @@ def verify_header( timestamp_in_seconds = int(issued_timestamp) / 1000 seconds_since_issued = current_time - timestamp_in_seconds - if seconds_since_issued > max_seconds_since_issued: + if abs(seconds_since_issued) > max_seconds_since_issued: raise ValueError("Timestamp outside the tolerance zone") body_str = ( @@ -520,7 +520,7 @@ def verify_header( timestamp_in_seconds = int(issued_timestamp) / 1000 seconds_since_issued = current_time - timestamp_in_seconds - if seconds_since_issued > max_seconds_since_issued: + if abs(seconds_since_issued) > max_seconds_since_issued: raise ValueError("Timestamp outside the tolerance zone") body_str = ( diff --git a/src/workos/webhooks/_verification.py b/src/workos/webhooks/_verification.py index 3dc99582..3ec82bef 100644 --- a/src/workos/webhooks/_verification.py +++ b/src/workos/webhooks/_verification.py @@ -78,12 +78,12 @@ def verify_header( issued_timestamp = issued_timestamp[2:] signature_hash = signature_hash[3:] - max_seconds_since_issued = tolerance or DEFAULT_TOLERANCE + max_seconds_since_issued = tolerance if tolerance is not None else DEFAULT_TOLERANCE current_time = time.time() timestamp_in_seconds = int(issued_timestamp) / 1000 seconds_since_issued = current_time - timestamp_in_seconds - if seconds_since_issued > max_seconds_since_issued: + if abs(seconds_since_issued) > max_seconds_since_issued: raise ValueError("Timestamp outside the tolerance zone") unhashed_string = "{0}.{1}".format(issued_timestamp, event_body.decode("utf-8")) diff --git a/tests/test_actions.py b/tests/test_actions.py index 1d65d8f3..eb107905 100644 --- a/tests/test_actions.py +++ b/tests/test_actions.py @@ -63,6 +63,17 @@ def test_verify_header_stale_timestamp(self): tolerance=30, ) + def test_verify_header_future_timestamp(self): + future_ts = int((time.time() + 60) * 1000) + sig = _make_sig_header(SAMPLE_ACTION_PAYLOAD, SECRET, future_ts) + with pytest.raises(ValueError, match="tolerance zone"): + self.actions.verify_header( + payload=SAMPLE_ACTION_PAYLOAD, + sig_header=sig, + secret=SECRET, + tolerance=30, + ) + def test_verify_header_custom_tolerance(self): old_ts = int((time.time() - 10) * 1000) sig = _make_sig_header(SAMPLE_ACTION_PAYLOAD, SECRET, old_ts) diff --git a/tests/test_session.py b/tests/test_session.py index 3355bfb2..15e44e78 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -7,6 +7,19 @@ from cryptography.hazmat.primitives.asymmetric import rsa from workos import WorkOSClient +from workos._errors import ( + AuthenticationError, + AuthenticationMethodNotAllowedError, + EmailVerificationRequiredError, + MfaChallengeError, + MfaEnrollmentError, + OrganizationAuthMethodsRequiredError, + OrganizationSelectionRequiredError, + RadarChallengeError, + SsoRequiredError, + WorkOSConnectionError, + WorkOSTimeoutError, +) from workos.session import ( AsyncSession, AuthenticateWithSessionCookieErrorResponse, @@ -14,6 +27,7 @@ AuthenticateWithSessionCookieSuccessResponse, RefreshWithSessionCookieErrorResponse, Session, + _map_refresh_exception_to_reason, seal_data, seal_session_from_auth_response, unseal_data, @@ -212,6 +226,64 @@ def test_session_refresh_missing_refresh_token(self): assert isinstance(result, RefreshWithSessionCookieErrorResponse) +class TestMapRefreshExceptionToReason: + @pytest.mark.parametrize( + "exc, expected", + [ + ( + MfaChallengeError("mfa challenge"), + AuthenticateWithSessionCookieFailureReason.MFA_CHALLENGE_REQUIRED, + ), + ( + MfaEnrollmentError("mfa enrollment"), + AuthenticateWithSessionCookieFailureReason.MFA_ENROLLMENT_REQUIRED, + ), + ( + SsoRequiredError("sso required"), + AuthenticateWithSessionCookieFailureReason.SSO_REQUIRED, + ), + ( + EmailVerificationRequiredError("email verification required"), + AuthenticateWithSessionCookieFailureReason.EMAIL_VERIFICATION_REQUIRED, + ), + ( + OrganizationSelectionRequiredError("org selection required"), + AuthenticateWithSessionCookieFailureReason.ORGANIZATION_SELECTION_REQUIRED, + ), + ( + OrganizationAuthMethodsRequiredError("org auth methods required"), + AuthenticateWithSessionCookieFailureReason.ORGANIZATION_AUTH_METHODS_REQUIRED, + ), + ( + AuthenticationMethodNotAllowedError("method not allowed"), + AuthenticateWithSessionCookieFailureReason.AUTHENTICATION_METHOD_NOT_ALLOWED, + ), + ( + RadarChallengeError("radar challenge"), + AuthenticateWithSessionCookieFailureReason.RADAR_CHALLENGE_REQUIRED, + ), + ( + AuthenticationError("unauthorized"), + AuthenticateWithSessionCookieFailureReason.REFRESH_DENIED, + ), + ( + WorkOSConnectionError("connection failed"), + AuthenticateWithSessionCookieFailureReason.REFRESH_NETWORK_ERROR, + ), + ( + WorkOSTimeoutError("timeout"), + AuthenticateWithSessionCookieFailureReason.REFRESH_NETWORK_ERROR, + ), + ], + ) + def test_known_exceptions_map_to_reason(self, exc, expected): + assert _map_refresh_exception_to_reason(exc) == expected + + def test_unknown_exception_falls_back_to_string(self): + result = _map_refresh_exception_to_reason(RuntimeError("boom")) + assert result == "boom" + + @pytest.mark.asyncio class TestAsyncSession: def _mock_jwks(self, public_key): diff --git a/tests/test_webhook_verification.py b/tests/test_webhook_verification.py index aaa2ea05..ba98d5f1 100644 --- a/tests/test_webhook_verification.py +++ b/tests/test_webhook_verification.py @@ -124,6 +124,17 @@ def test_verify_header_stale_timestamp(self, workos): tolerance=180, ) + def test_verify_header_future_timestamp(self, workos): + future_ts = int((time.time() + 300) * 1000) + sig = _make_sig_header(SAMPLE_EVENT, SECRET, future_ts) + with pytest.raises(ValueError, match="tolerance zone"): + workos.webhooks.verify_header( + event_body=SAMPLE_EVENT, + event_signature=sig, + secret=SECRET, + tolerance=180, + ) + class TestStandaloneVerifyEvent: def test_standalone_verify_event(self): @@ -157,3 +168,25 @@ def test_standalone_verify_header_invalid(self): event_signature=sig, secret=SECRET, ) + + def test_standalone_verify_header_future_timestamp(self): + future_ts = int((time.time() + 300) * 1000) + sig = _make_sig_header(SAMPLE_EVENT, SECRET, future_ts) + with pytest.raises(ValueError, match="tolerance zone"): + standalone_verify_header( + event_body=SAMPLE_EVENT.encode("utf-8"), + event_signature=sig, + secret=SECRET, + tolerance=180, + ) + + def test_standalone_verify_header_tolerance_zero_rejects_old_timestamp(self): + old_ts = int((time.time() - 1) * 1000) + sig = _make_sig_header(SAMPLE_EVENT, SECRET, old_ts) + with pytest.raises(ValueError, match="tolerance zone"): + standalone_verify_header( + event_body=SAMPLE_EVENT.encode("utf-8"), + event_signature=sig, + secret=SECRET, + tolerance=0, + )