-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathworkos_auth.py
More file actions
226 lines (186 loc) · 7.73 KB
/
workos_auth.py
File metadata and controls
226 lines (186 loc) · 7.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
"""
WorkOS Authentication Module
This module provides WorkOS JWT token authentication for protected routes.
"""
from fastapi import HTTPException, Request
from pydantic import BaseModel
from loguru import logger
from typing import Any
import jwt
import sys
from jwt.exceptions import DecodeError, InvalidTokenError, PyJWKClientError
from jwt import PyJWKClient
from workos import WorkOSClient
from src.utils.logging_config import setup_logging
from common import global_config
# Setup logging at module import
setup_logging()
# Initialize WorkOS JWKS client (cached at module level)
WORKOS_JWKS_URL = f"https://api.workos.com/sso/jwks/{global_config.WORKOS_CLIENT_ID}"
WORKOS_ISSUER = "https://api.workos.com"
WORKOS_ACCESS_ISSUER = (
f"{WORKOS_ISSUER}/user_management/{global_config.WORKOS_CLIENT_ID}"
)
WORKOS_ALLOWED_ISSUERS = [WORKOS_ISSUER, WORKOS_ACCESS_ISSUER]
WORKOS_AUDIENCE = global_config.WORKOS_CLIENT_ID
# Create JWKS client instance (will cache keys automatically)
_jwks_client: PyJWKClient | None = None
# WorkOS API client (cached)
_workos_client: WorkOSClient | None = None
def get_jwks_client() -> PyJWKClient:
"""Get or create the WorkOS JWKS client instance."""
global _jwks_client
if _jwks_client is None:
_jwks_client = PyJWKClient(WORKOS_JWKS_URL)
return _jwks_client
def get_workos_client() -> WorkOSClient:
"""Get or create the WorkOS API client instance."""
global _workos_client
if _workos_client is None:
_workos_client = WorkOSClient(
api_key=global_config.WORKOS_API_KEY,
client_id=global_config.WORKOS_CLIENT_ID,
)
return _workos_client
class WorkOSUser(BaseModel):
"""WorkOS user model"""
id: str # noqa
email: str | None = None # noqa
first_name: str | None = None # noqa
last_name: str | None = None # noqa
@classmethod
def from_workos_token(cls, token_data: dict[str, Any]):
"""Create WorkOSUser from decoded JWT token data"""
return cls(
id=token_data.get("sub", ""),
email=token_data.get("email"),
first_name=token_data.get("first_name"),
last_name=token_data.get("last_name"),
)
def _hydrate_user_from_workos_api(user: WorkOSUser) -> WorkOSUser:
"""
Populate missing user fields (like email) via the WorkOS User Management API.
Some WorkOS-issued access tokens omit profile fields. When email is missing,
we fetch the full user record using the user id from the token subject.
"""
if user.email:
return user
try:
workos_client = get_workos_client()
remote_user = workos_client.user_management.get_user(user.id)
user.email = getattr(remote_user, "email", None)
if not user.first_name:
user.first_name = getattr(remote_user, "first_name", None)
if not user.last_name:
user.last_name = getattr(remote_user, "last_name", None)
if not user.email:
logger.warning(f"No email returned from WorkOS for user {user.id}")
except Exception as exc:
logger.warning(
f"Unable to fetch WorkOS user details for {user.id}: {exc}",
exc_info=exc,
)
return user
async def get_current_workos_user(request: Request) -> WorkOSUser:
"""
Validate the user's WorkOS JWT token and return the user.
WorkOS tokens are JWTs that can be verified using the WorkOS client ID.
Args:
request: FastAPI request object
Returns:
WorkOSUser object with user information
Raises:
HTTPException: If token is missing, invalid, or expired
"""
auth_header = request.headers.get("Authorization")
if not auth_header:
raise HTTPException(status_code=401, detail="Missing authorization header")
if not auth_header.lower().startswith("bearer "):
raise HTTPException(
status_code=401,
detail="Invalid authorization header format. Expected 'Bearer <token>'",
)
try:
# Extract token
token = auth_header.split(" ", 1)[1]
# Check if we're in test mode (skip signature verification for tests)
# Detect test mode by checking if pytest is running
is_test_mode = "pytest" in sys.modules or "test" in sys.argv[0].lower()
# Determine whether the token declares an audience so we can decide
# whether to enforce audience verification (access tokens currently omit aud).
try:
unverified_claims = jwt.decode(
token,
options={
"verify_signature": False,
"verify_exp": False,
"verify_iss": False,
"verify_aud": False,
},
)
has_audience = "aud" in unverified_claims
except Exception:
# If we cannot read claims without verification, fall back to enforcing aud
has_audience = True
# Verify and decode the JWT token using WorkOS JWKS
try:
if is_test_mode:
# In test mode, decode without signature verification
# Tests use HS256 tokens with test secrets
decoded_token = jwt.decode(
token,
options={
"verify_signature": False,
"verify_exp": False,
"verify_iss": False,
"verify_aud": False,
},
)
logger.debug("Decoded test token without signature verification")
else:
# Production mode: verify signature using WorkOS JWKS
jwks_client = get_jwks_client()
# Get the signing key from JWKS
signing_key = jwks_client.get_signing_key_from_jwt(token)
# Decode and verify the JWT token with signature verification
decode_options = {
"verify_signature": True,
"verify_exp": True,
"verify_iss": True,
"verify_aud": has_audience,
}
if not has_audience:
logger.debug(
"WorkOS token missing 'aud' claim; skipping audience verification"
)
decoded_token = jwt.decode(
token,
signing_key.key,
algorithms=["RS256"], # WorkOS uses RS256 for JWT signing
issuer=WORKOS_ALLOWED_ISSUERS,
audience=WORKOS_AUDIENCE if has_audience else None,
options=decode_options,
)
except (DecodeError, InvalidTokenError, PyJWKClientError) as e:
logger.error(f"Invalid WorkOS token or JWKS lookup failed: {e}")
raise HTTPException(
status_code=401, detail="Invalid or expired token. Please log in again."
)
# Create user object from token data
user = WorkOSUser.from_workos_token(decoded_token)
if not user.id:
logger.error(f"Token missing required user id: {decoded_token}")
raise HTTPException(
status_code=401,
detail="Invalid token: missing required user id information",
)
# Fetch missing profile fields (e.g., email) from the WorkOS API if needed.
user = _hydrate_user_from_workos_api(user)
logger.debug(f"Successfully authenticated WorkOS user: {user.email}")
return user
except HTTPException:
# Re-raise HTTP exceptions
raise
except Exception as e:
logger.exception(f"Unexpected error in WorkOS authentication: {e}")
raise HTTPException(status_code=500, detail="Internal server error")