2727import threading
2828import time
2929import uuid
30- from typing import Optional
30+ from typing import Any , Dict , Optional
3131from urllib .parse import urlparse , urlunparse
3232
3333from Cryptodome .PublicKey import RSA
@@ -58,7 +58,7 @@ class DPoPProofGenerator:
5858 - Keys are rotated periodically for better security
5959 """
6060
61- def __init__ (self , config : dict ) :
61+ def __init__ (self , config : Dict [ str , Any ]) -> None :
6262 """
6363 Initialize DPoP proof generator.
6464
@@ -67,12 +67,15 @@ def __init__(self, config: dict):
6767 - dpopKeyRotationInterval: Key rotation interval in seconds (default: 86400 / 24 hours)
6868 """
6969 self ._rsa_key : Optional [RSA .RsaKey ] = None
70- self ._public_jwk : Optional [dict ] = None
70+ self ._public_jwk : Optional [Dict [ str , str ] ] = None
7171 self ._key_created_at : Optional [float ] = None
7272 self ._rotation_interval : int = config .get ('dpopKeyRotationInterval' , 86400 ) # 24h default
7373 self ._nonce : Optional [str ] = None
74- self ._lock = threading .Lock () # Thread-safe lock for key operations
75- self ._active_requests = 0 # Track active requests for safe key rotation
74+
75+ # Use RLock for reentrant lock support
76+ # This allows the same thread to acquire the lock multiple times
77+ self ._lock : threading .RLock = threading .RLock ()
78+ self ._active_requests : int = 0 # Track active requests for safe key rotation
7679
7780 # Generate initial keys
7881 self ._rotate_keys_internal ()
@@ -85,7 +88,7 @@ def _rotate_keys_internal(self) -> None:
8588
8689 Generates a new RSA 2048-bit key pair and exports the public key as JWK.
8790 """
88- logger .info ("Generating new RSA 2048 -bit key pair for DPoP" )
91+ logger .info ("Generating new RSA 3072 -bit key pair for DPoP" )
8992 self ._rsa_key = RSA .generate (3072 )
9093 self ._public_jwk = self ._export_public_jwk ()
9194 self ._key_created_at = time .time ()
@@ -125,6 +128,8 @@ def generate_proof_jwt(
125128 Generate DPoP proof JWT per RFC 9449.
126129
127130 FIX #1: Strips query parameters and fragments from http_url per RFC 9449 Section 4.2.
131+ FIX #5 (IMPROVED): Thread-safe key access with proper lock protection to prevent
132+ race conditions during key rotation.
128133
129134 Args:
130135 http_method: HTTP method (GET, POST, etc.)
@@ -146,15 +151,24 @@ def generate_proof_jwt(
146151 ... access_token='eyJhbG...'
147152 ... )
148153 """
149- # FIX #5: Increment active request counter (thread-safe)
154+ # FIX #5 (IMPROVED): Acquire lock and capture key references atomically
155+ # This prevents race condition where rotation could happen between
156+ # counter increment and key usage
150157 with self ._lock :
151158 self ._active_requests += 1
152159
160+ # Capture key references while holding lock
161+ # This ensures we use consistent key state throughout JWT generation
162+ rsa_key = self ._rsa_key
163+ public_jwk = self ._public_jwk
164+ key_created_at = self ._key_created_at
165+ stored_nonce = self ._nonce
166+
153167 try :
154168 # Check if auto-rotation is needed (but don't rotate during active request)
155- if self . _should_rotate_keys () :
169+ if key_created_at and ( time . time () - key_created_at ) >= self . _rotation_interval :
156170 logger .warning (
157- f"DPoP keys are { time .time () - self . _key_created_at :.0f} s old, "
171+ f"DPoP keys are { time .time () - key_created_at :.0f} s old, "
158172 f"rotation recommended (interval: { self ._rotation_interval } s)"
159173 )
160174
@@ -187,7 +201,7 @@ def generate_proof_jwt(
187201 }
188202
189203 # Add optional nonce claim (use provided or stored)
190- effective_nonce = nonce or self . _nonce
204+ effective_nonce = nonce or stored_nonce
191205 if effective_nonce :
192206 claims ['nonce' ] = effective_nonce
193207 logger .debug (f"Added nonce to DPoP proof: { effective_nonce [:8 ]} ..." )
@@ -201,13 +215,13 @@ def generate_proof_jwt(
201215 headers = {
202216 'typ' : 'dpop+jwt' ,
203217 'alg' : 'RS256' ,
204- 'jwk' : self . _public_jwk
218+ 'jwk' : public_jwk
205219 }
206220
207- # Sign JWT with private key
221+ # Sign JWT with private key (using captured reference)
208222 token = jwt_encode (
209223 claims ,
210- self . _rsa_key .export_key (),
224+ rsa_key .export_key (),
211225 algorithm = 'RS256' ,
212226 headers = headers
213227 )
@@ -221,7 +235,7 @@ def generate_proof_jwt(
221235 return token
222236
223237 finally :
224- # FIX #5: Decrement active request counter (thread-safe)
238+ # FIX #5 (IMPROVED) : Decrement counter (thread-safe)
225239 with self ._lock :
226240 self ._active_requests -= 1
227241
@@ -260,7 +274,7 @@ def _compute_access_token_hash(self, access_token: str) -> str:
260274 logger .debug (f"Computed access token hash: { ath [:16 ]} ..." )
261275 return ath
262276
263- def _export_public_jwk (self ) -> dict :
277+ def _export_public_jwk (self ) -> Dict [ str , str ] :
264278 """
265279 Export ONLY public key components as JWK per RFC 7517.
266280
@@ -269,7 +283,7 @@ def _export_public_jwk(self) -> dict:
269283 and MUST NOT contain a private key.
270284
271285 Returns:
272- dict : JWK with only public components (kty, n, e)
286+ Dict[str, str] : JWK with only public components (kty, n, e)
273287
274288 Security Note:
275289 This method uses jwcrypto.export_public() to ensure only public
@@ -331,12 +345,12 @@ def get_nonce(self) -> Optional[str]:
331345 """
332346 return self ._nonce
333347
334- def get_public_jwk (self ) -> dict :
348+ def get_public_jwk (self ) -> Dict [ str , str ] :
335349 """
336350 Get public key in JWK format.
337351
338352 Returns:
339- Copy of the public JWK (kty, n, e)
353+ Dict[str, str]: Copy of the public JWK (kty, n, e)
340354 """
341355 return self ._public_jwk .copy () if self ._public_jwk else {}
342356
0 commit comments