Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions src/workos/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from typing import Any, Dict, Optional, Type, cast, overload
from urllib.parse import quote

import httpx

Expand Down Expand Up @@ -53,8 +54,15 @@ def __init__(
request_timeout: Optional[int] = None,
jwt_leeway: float = 0.0,
max_retries: int = MAX_RETRIES,
is_public: bool = False,
) -> None:
self._api_key = api_key or os.environ.get("WORKOS_API_KEY")
self._is_public = is_public
# Public clients (PKCE / browser / mobile / CLI) must never attach
# an API key, even if WORKOS_API_KEY is present in the environment.
if is_public:
self._api_key: Optional[str] = None
else:
self._api_key = api_key or os.environ.get("WORKOS_API_KEY")
self.client_id = client_id or os.environ.get("WORKOS_CLIENT_ID")
if not self._api_key and not self.client_id:
raise ValueError(
Expand Down Expand Up @@ -128,6 +136,21 @@ def _resolve_base_url(self, request_options: Optional[RequestOptions]) -> str:
return str(base_url).rstrip("/")
return self._base_url.rstrip("/")

@staticmethod
def _encode_path(path: str) -> str:
"""Percent-encode each path segment to prevent path-traversal/injection.

Splits on ``/`` and applies ``urllib.parse.quote(seg, safe='')`` to each
segment so that user-supplied IDs containing reserved characters (``/``,
``?``, ``#``, ``%``, etc.) cannot escape their intended segment. The
leading slash (if any) is preserved.
"""
if not path:
return path
leading = "/" if path.startswith("/") else ""
body = path[1:] if leading else path
return leading + "/".join(quote(seg, safe="") for seg in body.split("/"))

def _resolve_timeout(self, request_options: Optional[RequestOptions]) -> float:
timeout = self._request_timeout
if request_options:
Expand Down Expand Up @@ -332,6 +355,7 @@ def __init__(
request_timeout: Optional[int] = None,
jwt_leeway: float = 0.0,
max_retries: int = MAX_RETRIES,
is_public: bool = False,
) -> None:
"""Initialize the WorkOS client.

Expand All @@ -342,6 +366,10 @@ def __init__(
request_timeout: HTTP request timeout in seconds. Falls back to WORKOS_REQUEST_TIMEOUT or 60.
jwt_leeway: JWT clock skew leeway in seconds.
max_retries: Maximum number of retries for failed requests. Defaults to 3.
is_public: When True, mark this client as public (PKCE / browser
/ mobile / CLI). The API key is forced to None and the
``WORKOS_API_KEY`` environment variable is ignored. Use
``create_public_client`` instead of setting this directly.

Raises:
ValueError: If neither api_key nor client_id is provided, directly or via environment variables.
Expand All @@ -353,6 +381,7 @@ def __init__(
request_timeout=request_timeout,
jwt_leeway=jwt_leeway,
max_retries=max_retries,
is_public=is_public,
)
self._client = httpx.Client(
timeout=self._request_timeout, follow_redirects=True
Expand Down Expand Up @@ -406,7 +435,7 @@ def request(
request_options: Optional[RequestOptions] = None,
) -> Any:
"""Make an HTTP request with retry logic."""
url = f"{self._resolve_base_url(request_options)}/{path}"
url = f"{self._resolve_base_url(request_options)}/{self._encode_path(path).lstrip('/')}"
headers = self._build_headers(method, idempotency_key, request_options)
timeout = self._resolve_timeout(request_options)
max_retries = self._resolve_max_retries(request_options)
Expand Down Expand Up @@ -557,6 +586,7 @@ def __init__(
request_timeout: Optional[int] = None,
jwt_leeway: float = 0.0,
max_retries: int = MAX_RETRIES,
is_public: bool = False,
) -> None:
"""Initialize the async WorkOS client.

Expand All @@ -578,6 +608,7 @@ def __init__(
request_timeout=request_timeout,
jwt_leeway=jwt_leeway,
max_retries=max_retries,
is_public=is_public,
)
self._client = httpx.AsyncClient(
timeout=self._request_timeout, follow_redirects=True
Expand Down Expand Up @@ -631,7 +662,7 @@ async def request(
request_options: Optional[RequestOptions] = None,
) -> Any:
"""Make an async HTTP request with retry logic."""
url = f"{self._resolve_base_url(request_options)}/{path}"
url = f"{self._resolve_base_url(request_options)}/{self._encode_path(path).lstrip('/')}"
headers = self._build_headers(method, idempotency_key, request_options)
timeout = self._resolve_timeout(request_options)
max_retries = self._resolve_max_retries(request_options)
Expand Down
2 changes: 1 addition & 1 deletion src/workos/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _verify_signature(
timestamp_in_seconds = int(issued_timestamp) / 1000
seconds_since_issued = current_time - timestamp_in_seconds

if seconds_since_issued > tolerance:
if abs(seconds_since_issued) > tolerance:
raise ValueError("Timestamp outside the tolerance zone")

body_str = payload.decode("utf-8") if isinstance(payload, bytes) else payload
Expand Down
2 changes: 2 additions & 0 deletions src/workos/public_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ def create_public_client(
from ._client import WorkOSClient

return WorkOSClient(
api_key=None,
client_id=client_id,
base_url=base_url,
request_timeout=request_timeout,
is_public=True,
)
63 changes: 61 additions & 2 deletions src/workos/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@
from cryptography.fernet import Fernet
from jwt import PyJWKClient

from ._errors import (
AuthenticationError,
AuthenticationMethodNotAllowedError,
EmailVerificationRequiredError,
MfaChallengeError,
MfaEnrollmentError,
OrganizationAuthMethodsRequiredError,
OrganizationSelectionRequiredError,
RadarChallengeError,
SsoRequiredError,
WorkOSConnectionError,
WorkOSTimeoutError,
)

if TYPE_CHECKING:
from ._client import AsyncWorkOSClient, WorkOSClient

Expand All @@ -37,6 +51,51 @@ class AuthenticateWithSessionCookieFailureReason(Enum):
INVALID_JWT = "invalid_jwt"
INVALID_SESSION_COOKIE = "invalid_session_cookie"
NO_SESSION_COOKIE_PROVIDED = "no_session_cookie_provided"
MFA_CHALLENGE_REQUIRED = "mfa_challenge_required"
MFA_ENROLLMENT_REQUIRED = "mfa_enrollment_required"
SSO_REQUIRED = "sso_required"
EMAIL_VERIFICATION_REQUIRED = "email_verification_required"
ORGANIZATION_SELECTION_REQUIRED = "organization_selection_required"
ORGANIZATION_AUTH_METHODS_REQUIRED = "organization_auth_methods_required"
AUTHENTICATION_METHOD_NOT_ALLOWED = "authentication_method_not_allowed"
RADAR_CHALLENGE_REQUIRED = "radar_challenge_required"
REFRESH_DENIED = "refresh_denied"
REFRESH_NETWORK_ERROR = "refresh_network_error"


def _map_refresh_exception_to_reason(
exc: Exception,
) -> Union[AuthenticateWithSessionCookieFailureReason, str]:
"""Map an exception raised by a refresh request to a structured reason.

Falls back to ``str(exc)`` for unknown errors so callers retain the
pre-existing string form for diagnostics.
"""
if isinstance(exc, MfaChallengeError):
return AuthenticateWithSessionCookieFailureReason.MFA_CHALLENGE_REQUIRED
if isinstance(exc, MfaEnrollmentError):
return AuthenticateWithSessionCookieFailureReason.MFA_ENROLLMENT_REQUIRED
if isinstance(exc, SsoRequiredError):
return AuthenticateWithSessionCookieFailureReason.SSO_REQUIRED
if isinstance(exc, EmailVerificationRequiredError):
return AuthenticateWithSessionCookieFailureReason.EMAIL_VERIFICATION_REQUIRED
if isinstance(exc, OrganizationSelectionRequiredError):
return (
AuthenticateWithSessionCookieFailureReason.ORGANIZATION_SELECTION_REQUIRED
)
if isinstance(exc, OrganizationAuthMethodsRequiredError):
return AuthenticateWithSessionCookieFailureReason.ORGANIZATION_AUTH_METHODS_REQUIRED
if isinstance(exc, AuthenticationMethodNotAllowedError):
return (
AuthenticateWithSessionCookieFailureReason.AUTHENTICATION_METHOD_NOT_ALLOWED
)
if isinstance(exc, RadarChallengeError):
return AuthenticateWithSessionCookieFailureReason.RADAR_CHALLENGE_REQUIRED
if isinstance(exc, AuthenticationError):
return AuthenticateWithSessionCookieFailureReason.REFRESH_DENIED
if isinstance(exc, (WorkOSConnectionError, WorkOSTimeoutError)):
return AuthenticateWithSessionCookieFailureReason.REFRESH_NETWORK_ERROR
return str(exc)


@dataclass(slots=True)
Comment thread
greptile-apps[bot] marked this conversation as resolved.
Expand Down Expand Up @@ -328,7 +387,7 @@ def refresh(
)
except Exception as e:
return RefreshWithSessionCookieErrorResponse(
authenticated=False, reason=str(e)
authenticated=False, reason=_map_refresh_exception_to_reason(e)
)

def get_logout_url(self, return_to: Optional[str] = None) -> str:
Expand Down Expand Up @@ -507,7 +566,7 @@ async def refresh(
)
except Exception as e:
return RefreshWithSessionCookieErrorResponse(
authenticated=False, reason=str(e)
authenticated=False, reason=_map_refresh_exception_to_reason(e)
)

async def get_logout_url(self, return_to: Optional[str] = None) -> str:
Expand Down
4 changes: 3 additions & 1 deletion src/workos/vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,12 @@ def _decode_u32_leb128(buf: bytes) -> Tuple[int, int]:
res = 0
bit = 0
for i, b in enumerate(buf):
if i > 4:
if i >= 4 and (b & 0x80) != 0:
raise ValueError("LEB128 integer overflow (was more than 4 bytes)")
res |= (b & 0x7F) << (7 * bit)
if (b & 0x80) == 0:
if res > 0xFFFFFFFF:
raise ValueError("LEB128 integer overflow (exceeds 32 bits)")
return res, i + 1
bit += 1
raise ValueError("LEB128 integer not found")
Expand Down
4 changes: 2 additions & 2 deletions src/workos/webhooks/_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def verify_header(
timestamp_in_seconds = int(issued_timestamp) / 1000
seconds_since_issued = current_time - timestamp_in_seconds

if seconds_since_issued > max_seconds_since_issued:
if abs(seconds_since_issued) > max_seconds_since_issued:
raise ValueError("Timestamp outside the tolerance zone")

body_str = (
Expand Down Expand Up @@ -520,7 +520,7 @@ def verify_header(
timestamp_in_seconds = int(issued_timestamp) / 1000
seconds_since_issued = current_time - timestamp_in_seconds

if seconds_since_issued > max_seconds_since_issued:
if abs(seconds_since_issued) > max_seconds_since_issued:
raise ValueError("Timestamp outside the tolerance zone")

body_str = (
Expand Down
4 changes: 2 additions & 2 deletions src/workos/webhooks/_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,12 @@ def verify_header(

issued_timestamp = issued_timestamp[2:]
signature_hash = signature_hash[3:]
max_seconds_since_issued = tolerance or DEFAULT_TOLERANCE
max_seconds_since_issued = tolerance if tolerance is not None else DEFAULT_TOLERANCE
current_time = time.time()
timestamp_in_seconds = int(issued_timestamp) / 1000
seconds_since_issued = current_time - timestamp_in_seconds

if seconds_since_issued > max_seconds_since_issued:
if abs(seconds_since_issued) > max_seconds_since_issued:
raise ValueError("Timestamp outside the tolerance zone")

unhashed_string = "{0}.{1}".format(issued_timestamp, event_body.decode("utf-8"))
Expand Down
11 changes: 11 additions & 0 deletions tests/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,17 @@ def test_verify_header_stale_timestamp(self):
tolerance=30,
)

def test_verify_header_future_timestamp(self):
future_ts = int((time.time() + 60) * 1000)
sig = _make_sig_header(SAMPLE_ACTION_PAYLOAD, SECRET, future_ts)
with pytest.raises(ValueError, match="tolerance zone"):
self.actions.verify_header(
payload=SAMPLE_ACTION_PAYLOAD,
sig_header=sig,
secret=SECRET,
tolerance=30,
)

def test_verify_header_custom_tolerance(self):
old_ts = int((time.time() - 10) * 1000)
sig = _make_sig_header(SAMPLE_ACTION_PAYLOAD, SECRET, old_ts)
Expand Down
72 changes: 72 additions & 0 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,27 @@
from cryptography.hazmat.primitives.asymmetric import rsa

from workos import WorkOSClient
from workos._errors import (
AuthenticationError,
AuthenticationMethodNotAllowedError,
EmailVerificationRequiredError,
MfaChallengeError,
MfaEnrollmentError,
OrganizationAuthMethodsRequiredError,
OrganizationSelectionRequiredError,
RadarChallengeError,
SsoRequiredError,
WorkOSConnectionError,
WorkOSTimeoutError,
)
from workos.session import (
AsyncSession,
AuthenticateWithSessionCookieErrorResponse,
AuthenticateWithSessionCookieFailureReason,
AuthenticateWithSessionCookieSuccessResponse,
RefreshWithSessionCookieErrorResponse,
Session,
_map_refresh_exception_to_reason,
seal_data,
seal_session_from_auth_response,
unseal_data,
Expand Down Expand Up @@ -212,6 +226,64 @@ def test_session_refresh_missing_refresh_token(self):
assert isinstance(result, RefreshWithSessionCookieErrorResponse)


class TestMapRefreshExceptionToReason:
@pytest.mark.parametrize(
"exc, expected",
[
(
MfaChallengeError("mfa challenge"),
AuthenticateWithSessionCookieFailureReason.MFA_CHALLENGE_REQUIRED,
),
(
MfaEnrollmentError("mfa enrollment"),
AuthenticateWithSessionCookieFailureReason.MFA_ENROLLMENT_REQUIRED,
),
(
SsoRequiredError("sso required"),
AuthenticateWithSessionCookieFailureReason.SSO_REQUIRED,
),
(
EmailVerificationRequiredError("email verification required"),
AuthenticateWithSessionCookieFailureReason.EMAIL_VERIFICATION_REQUIRED,
),
(
OrganizationSelectionRequiredError("org selection required"),
AuthenticateWithSessionCookieFailureReason.ORGANIZATION_SELECTION_REQUIRED,
),
(
OrganizationAuthMethodsRequiredError("org auth methods required"),
AuthenticateWithSessionCookieFailureReason.ORGANIZATION_AUTH_METHODS_REQUIRED,
),
(
AuthenticationMethodNotAllowedError("method not allowed"),
AuthenticateWithSessionCookieFailureReason.AUTHENTICATION_METHOD_NOT_ALLOWED,
),
(
RadarChallengeError("radar challenge"),
AuthenticateWithSessionCookieFailureReason.RADAR_CHALLENGE_REQUIRED,
),
(
AuthenticationError("unauthorized"),
AuthenticateWithSessionCookieFailureReason.REFRESH_DENIED,
),
(
WorkOSConnectionError("connection failed"),
AuthenticateWithSessionCookieFailureReason.REFRESH_NETWORK_ERROR,
),
(
WorkOSTimeoutError("timeout"),
AuthenticateWithSessionCookieFailureReason.REFRESH_NETWORK_ERROR,
),
],
)
def test_known_exceptions_map_to_reason(self, exc, expected):
assert _map_refresh_exception_to_reason(exc) == expected

def test_unknown_exception_falls_back_to_string(self):
result = _map_refresh_exception_to_reason(RuntimeError("boom"))
assert result == "boom"


@pytest.mark.asyncio
class TestAsyncSession:
def _mock_jwks(self, public_key):
Expand Down
Loading
Loading