-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathoauth2_auth.py
More file actions
268 lines (236 loc) · 9.82 KB
/
oauth2_auth.py
File metadata and controls
268 lines (236 loc) · 9.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
# AUTO-GENERATED FILE - DO NOT EDIT
# This file was automatically generated by the XDK build tool.
# Any manual changes will be overwritten on the next generation.
"""
Auto-generated OAuth2 PKCE authentication for the X API.
This module provides OAuth2 PKCE (Proof Key for Code Exchange) authentication
functionality for secure authorization flows. Includes code verifier generation,
token management, and automatic token refresh capabilities.
Generated automatically - do not edit manually.
"""
import secrets
import base64
import hashlib
import time
import urllib.parse
from typing import Dict, Optional, Any, Tuple, Union, List
import requests
from requests.auth import HTTPBasicAuth
from requests_oauthlib import OAuth2Session
class OAuth2PKCEAuth:
    """OAuth2 PKCE authentication for the X API.

    Drives the authorization-code flow with PKCE (RFC 7636):
    - generates the code verifier/challenge pair;
    - builds the authorization URL;
    - exchanges the returned code for a token;
    - refreshes that token.

    The current token is kept in ``self.token``. Once a client ID is known,
    it is also loaded into a ``requests_oauthlib.OAuth2Session`` stored in
    ``self.oauth2_session``.
    """

    # RFC 7636 section 4.1: a code verifier must be 43-128 characters long.
    _MIN_VERIFIER_LENGTH = 43
    _MAX_VERIFIER_LENGTH = 128

    def __init__(
        self,
        base_url: str = "https://api.x.com",
        authorization_base_url: str = "https://x.com/i",
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        redirect_uri: Optional[str] = None,
        token: Optional[Dict[str, Any]] = None,
        scope: Optional[Union[str, List[str]]] = None,
    ):
        """Initialize the OAuth2 PKCE authentication.

        Args:
            base_url: The base URL for the X API token endpoint (defaults to https://api.x.com).
            authorization_base_url: The base URL for OAuth2 authorization (defaults to https://x.com/i).
            client_id: The client ID for the X API.
            client_secret: The client secret for the X API. Leave as None for a
                public (PKCE-only) client.
            redirect_uri: The redirect URI for OAuth2 authorization.
            token: An existing OAuth2 token dictionary (if available).
            scope: Space-separated string or list of strings for OAuth2 authorization scopes.
        """
        # Strip trailing slashes from both bases so URLs built as
        # f"{base}/path" never contain a doubled slash.
        self.base_url = base_url.rstrip("/")
        self.authorization_base_url = authorization_base_url.rstrip("/")
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_uri = redirect_uri
        self.token = token
        # OAuth2 scopes are space-delimited (RFC 6749 section 3.3).
        if isinstance(scope, (list, tuple)):
            self.scope = " ".join(scope)
        else:
            self.scope = scope
        self.oauth2_session = None
        self.code_verifier = None
        self.code_challenge = None
        # State value embedded in the most recent authorization URL; callers
        # can compare it against the `state` returned on the callback.
        self.state: Optional[str] = None
        # Set up OAuth2 session if we have a token
        if token and client_id:
            self._setup_oauth_session()

    def _setup_oauth_session(self):
        """Set up the OAuth2 session with the existing token."""
        self.oauth2_session = OAuth2Session(
            client_id=self.client_id,
            token=self.token,
            redirect_uri=self.redirect_uri,
            scope=self.scope,
        )

    def _token_url(self) -> str:
        """Return the token endpoint used for both code exchange and refresh."""
        return f"{self.base_url}/2/oauth2/token"

    def _generate_code_verifier(self, length: int = 128) -> str:
        """Generate a random PKCE code verifier.

        Args:
            length: Number of characters in the verifier (43-128 per RFC 7636).

        Returns:
            str: A URL-safe random string of exactly ``length`` characters.

        Raises:
            ValueError: If ``length`` is outside the range allowed by RFC 7636.
        """
        if not self._MIN_VERIFIER_LENGTH <= length <= self._MAX_VERIFIER_LENGTH:
            raise ValueError(
                f"Code verifier length must be between {self._MIN_VERIFIER_LENGTH} "
                f"and {self._MAX_VERIFIER_LENGTH}, got {length}"
            )
        # token_urlsafe(n) base64url-encodes n random bytes, giving about
        # 4n/3 characters. Requesting `length` bytes therefore always
        # produces at least `length` characters to slice from.
        return secrets.token_urlsafe(length)[:length]

    def _generate_code_challenge(self, code_verifier: str) -> str:
        """Generate an S256 code challenge from the code verifier.

        Args:
            code_verifier: The code verifier to generate a challenge from.

        Returns:
            str: base64url(SHA-256(code_verifier)) with "=" padding removed.
        """
        digest = hashlib.sha256(code_verifier.encode()).digest()
        return base64.urlsafe_b64encode(digest).decode().rstrip("=")

    def set_pkce_parameters(
        self, code_verifier: str, code_challenge: Optional[str] = None
    ):
        """Manually set PKCE parameters.

        Args:
            code_verifier: The code verifier to use.
            code_challenge: Optional code challenge. When omitted, it is
                derived from ``code_verifier`` using S256.
        """
        self.code_verifier = code_verifier
        if code_challenge:
            self.code_challenge = code_challenge
        else:
            self.code_challenge = self._generate_code_challenge(code_verifier)

    def get_authorization_url(self, state: Optional[str] = None) -> str:
        """Get the authorization URL for the OAuth2 PKCE flow.

        Generates PKCE parameters if none are set yet and starts a fresh
        OAuth2 session. The state actually sent is stored in ``self.state``.

        Args:
            state: Optional state value. When None, requests-oauthlib
                generates one.

        Returns:
            str: The authorization URL.
        """
        # Auto-generate PKCE parameters if not already set
        if not self.code_verifier or not self.code_challenge:
            self.code_verifier = self._generate_code_verifier()
            self.code_challenge = self._generate_code_challenge(self.code_verifier)
        self.oauth2_session = OAuth2Session(
            client_id=self.client_id,
            redirect_uri=self.redirect_uri,
            scope=self.scope,
            state=state,
        )
        # authorization_base_url serves the authorization endpoint;
        # base_url serves the API token endpoint.
        auth_url, self.state = self.oauth2_session.authorization_url(
            f"{self.authorization_base_url}/oauth2/authorize",
            code_challenge=self.code_challenge,
            code_challenge_method="S256",
        )
        return auth_url

    def exchange_code(
        self, code: str, code_verifier: Optional[str] = None
    ) -> Dict[str, Any]:
        """Exchange an authorization code for tokens (matches TypeScript API).

        Args:
            code: The authorization code from the callback.
            code_verifier: Optional code verifier (uses stored verifier if not provided).

        Returns:
            Dict[str, Any]: The token dictionary. It also contains
            ``expires_at`` when the server returned ``expires_in``.

        Raises:
            ValueError: If no code verifier is available or the token endpoint
                returns a non-2xx response.
        """
        if not code_verifier:
            code_verifier = self.code_verifier
        if not code_verifier:
            raise ValueError(
                "Code verifier is required. Call get_authorization_url() or set_pkce_parameters() first."
            )
        # Build the token exchange request manually to match TypeScript implementation
        params = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": self.redirect_uri,
            "code_verifier": code_verifier,
        }
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        # Confidential clients authenticate with HTTP Basic; public clients
        # identify themselves by sending client_id in the body instead.
        auth = None
        if self.client_secret:
            auth = HTTPBasicAuth(self.client_id, self.client_secret)
        else:
            params["client_id"] = self.client_id
        response = requests.post(
            self._token_url(), data=params, headers=headers, auth=auth
        )
        if not response.ok:
            try:
                error_data = response.json()
            except ValueError:
                # Body is not JSON (JSON decode errors subclass ValueError);
                # report the raw text instead.
                error_data = response.text
            raise ValueError(
                f"HTTP error! status: {response.status_code}, body: {error_data}"
            )
        data = response.json()
        self.token = {
            "access_token": data.get("access_token"),
            "token_type": data.get("token_type"),
            "expires_in": data.get("expires_in"),
            "refresh_token": data.get("refresh_token"),
            "scope": data.get("scope"),
        }
        # Record an absolute expiry so is_token_expired() can check it later.
        if data.get("expires_in"):
            self.token["expires_at"] = time.time() + float(data["expires_in"])
        # Set up OAuth2 session with the new token
        if self.client_id:
            self._setup_oauth_session()
        return self.token

    def fetch_token(self, authorization_response: str) -> Dict[str, Any]:
        """Fetch a token from the authorization callback URL.

        This is a legacy method that delegates to ``exchange_code``.

        Args:
            authorization_response: The full callback URL received after authorization.

        Returns:
            Dict[str, Any]: The token dictionary.

        Raises:
            ValueError: If the callback reports an authorization error or
                carries no ``code`` parameter.
        """
        parsed = urllib.parse.urlparse(authorization_response)
        query_params = urllib.parse.parse_qs(parsed.query)
        # RFC 6749 section 4.1.2.1: a failed or denied authorization
        # redirects back with `error` (and optionally `error_description`)
        # instead of `code`.
        if "error" in query_params:
            error = query_params["error"][0]
            description = query_params.get("error_description", [""])[0]
            detail = f": {description}" if description else ""
            raise ValueError(f"Authorization failed: {error}{detail}")
        if "code" not in query_params:
            raise ValueError("No authorization code found in callback URL")
        return self.exchange_code(query_params["code"][0])

    def refresh_token(self) -> Dict[str, Any]:
        """Refresh the access token.

        Returns:
            Dict[str, Any]: The refreshed token dictionary (also stored in ``self.token``).

        Raises:
            ValueError: If there is no session/token, or the token carries no
                refresh token.
        """
        if not self.oauth2_session or not self.token:
            raise ValueError("No token to refresh")
        stored_refresh = self.token.get("refresh_token")
        if not stored_refresh:
            raise ValueError("Token has no refresh_token; cannot refresh")
        # Authenticate the same way exchange_code does: use HTTP Basic when a
        # secret exists. Otherwise send client_id in the body, rather than a
        # Basic header whose password would be the string "None".
        auth = None
        extra_body: Dict[str, Any] = {}
        if self.client_secret:
            auth = HTTPBasicAuth(self.client_id, self.client_secret)
        else:
            extra_body["client_id"] = self.client_id
        self.token = self.oauth2_session.refresh_token(
            self._token_url(),
            refresh_token=stored_refresh,
            auth=auth,
            **extra_body,
        )
        return self.token

    def get_code_verifier(self) -> Optional[str]:
        """Get the current code verifier (for PKCE).

        Returns:
            Optional[str]: The current code verifier, or None if not set.
        """
        return self.code_verifier

    def get_code_challenge(self) -> Optional[str]:
        """Get the current code challenge (for PKCE).

        Returns:
            Optional[str]: The current code challenge, or None if not set.
        """
        return self.code_challenge

    @property
    def access_token(self) -> Optional[str]:
        """Get the current access token.

        Returns:
            Optional[str]: The current access token, or None if no token exists.
        """
        if self.token:
            return self.token.get("access_token")
        return None

    def is_token_expired(self) -> bool:
        """Check whether the token is expired.

        A token with no recorded ``expires_at`` is treated as expired.

        Returns:
            bool: True if the token is expired (or expiry is unknown), False otherwise.
        """
        if not self.token or "expires_at" not in self.token:
            return True
        # 10-second buffer so a token about to expire is not used in-flight.
        return time.time() > (self.token["expires_at"] - 10)