|
1 | | -use std::sync::{Arc, Mutex}; |
| 1 | +use std::sync::Arc; |
| 2 | + |
| 3 | +use async_lock::{RwLock, RwLockUpgradableReadGuard}; |
2 | 4 |
|
3 | 5 | use crate::connection::stream::PgStream; |
4 | 6 | use crate::error::Error; |
@@ -30,63 +32,77 @@ const NONCE_ATTR: &str = "r"; |
30 | 32 | /// that affect the HMAC result. The password is not included in the cache key because it can only |
31 | 33 | /// change via `&mut self` on `PgConnectOptions`, which replaces the cache instance. |
32 | 34 | /// |
| 35 | +/// An async `RwLock` is used so that only one caller computes the key at a time; subsequent callers |
| 36 | +/// wait and then read the cached result. |
| 37 | +/// |
33 | 38 | /// According to [RFC-7677](https://datatracker.ietf.org/doc/html/rfc7677): |
34 | 39 | /// |
35 | 40 | /// > This computational cost can be avoided by caching the ClientKey (assuming the Salt and hash |
36 | 41 | /// > iteration-count is stable). |
37 | 42 | #[derive(Debug, Clone)] |
38 | 43 | pub struct ClientKeyCache { |
39 | | - inner: Arc<Mutex<Option<CacheInner>>>, |
| 44 | + inner: Arc<RwLock<Option<CacheEntry>>>, |
40 | 45 | } |
41 | 46 |
|
42 | 47 | #[derive(Debug)] |
43 | | -struct CacheInner { |
| 48 | +struct CacheEntry { |
| 49 | + // Keys |
44 | 50 | salt: Vec<u8>, |
45 | 51 | iterations: u32, |
| 52 | + |
| 53 | + // Values |
46 | 54 | salted_password: [u8; 32], |
47 | 55 | client_key: Hmac<Sha256>, |
48 | 56 | } |
49 | 57 |
|
| 58 | +impl CacheEntry { |
| 59 | + fn matches(&self, cont: &AuthenticationSaslContinue) -> bool { |
| 60 | + self.salt == cont.salt && self.iterations == cont.iterations |
| 61 | + } |
| 62 | +} |
| 63 | + |
50 | 64 | impl ClientKeyCache { |
51 | 65 | pub fn new() -> Self { |
52 | 66 | ClientKeyCache { |
53 | | - inner: Arc::new(Mutex::new(None)), |
| 67 | + inner: Arc::new(RwLock::new(None)), |
54 | 68 | } |
55 | 69 | } |
56 | 70 |
|
57 | | - fn get( |
| 71 | + /// Returns the cached salted password and client key HMAC if the cache matches the given |
| 72 | + /// salt and iteration count. Otherwise, computes and caches them. |
| 73 | + async fn get_or_compute( |
58 | 74 | &self, |
| 75 | + password: &str, |
59 | 76 | cont: &AuthenticationSaslContinue, |
60 | | - ) -> Option<([u8; 32], Hmac<Sha256>)> { |
61 | | - self.inner |
62 | | - .lock() |
63 | | - .expect("BUG: panicked while holding a lock") |
64 | | - .as_ref() |
65 | | - .and_then(|inner| { |
66 | | - if inner.salt == cont.salt && inner.iterations == cont.iterations { |
67 | | - Some((inner.salted_password, inner.client_key.clone())) |
68 | | - } else { |
69 | | - None |
70 | | - } |
71 | | - }) |
72 | | - } |
| 77 | + ) -> Result<([u8; 32], Hmac<Sha256>), Error> { |
| 78 | + let guard = self.inner.upgradable_read().await; |
73 | 79 |
|
74 | | - fn set( |
75 | | - &self, |
76 | | - cont: &AuthenticationSaslContinue, |
77 | | - salted_password: [u8; 32], |
78 | | - client_key: Hmac<Sha256>, |
79 | | - ) { |
80 | | - let mut inner = self |
81 | | - .inner |
82 | | - .lock() |
83 | | - .expect("BUG: panicked while holding a lock"); |
84 | | - *inner = Some(CacheInner { |
| 80 | + if let Some(entry) = guard.as_ref().filter(|e| e.matches(cont)) { |
| 81 | + return Ok((entry.salted_password, entry.client_key.clone())); |
| 82 | + } |
| 83 | + |
| 84 | + let mut guard = RwLockUpgradableReadGuard::upgrade(guard).await; |
| 85 | + |
| 86 | + // Re-check after acquiring the write lock, in case another caller populated the cache. |
| 87 | + if let Some(entry) = guard.as_ref().filter(|e| e.matches(cont)) { |
| 88 | + return Ok((entry.salted_password, entry.client_key.clone())); |
| 89 | + } |
| 90 | + |
| 91 | + // SaltedPassword := Hi(Normalize(password), salt, i) |
| 92 | + let salted_password = hi(password, &cont.salt, cont.iterations).await?; |
| 93 | + |
| 94 | + // ClientKey := HMAC(SaltedPassword, "Client Key") |
| 95 | + let client_key = |
| 96 | + Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?; |
| 97 | + |
| 98 | + *guard = Some(CacheEntry { |
85 | 99 | salt: cont.salt.clone(), |
86 | 100 | iterations: cont.iterations, |
87 | 101 | salted_password, |
88 | | - client_key, |
| 102 | + client_key: client_key.clone(), |
89 | 103 | }); |
| 104 | + |
| 105 | + Ok((salted_password, client_key)) |
90 | 106 | } |
91 | 107 | } |
92 | 108 |
|
@@ -160,28 +176,10 @@ pub(crate) async fn authenticate( |
160 | 176 | } |
161 | 177 | }; |
162 | 178 |
|
163 | | - let (salted_password, mut mac) = { |
164 | | - if let Some(cached) = options.sasl_client_key_cache.get(&cont) { |
165 | | - cached |
166 | | - } else { |
167 | | - // SaltedPassword := Hi(Normalize(password), salt, i) |
168 | | - let salted_password = hi( |
169 | | - options.password.as_deref().unwrap_or_default(), |
170 | | - &cont.salt, |
171 | | - cont.iterations, |
172 | | - ) |
173 | | - .await?; |
174 | | - |
175 | | - // ClientKey := HMAC(SaltedPassword, "Client Key") |
176 | | - let mac = Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?; |
177 | | - |
178 | | - options |
179 | | - .sasl_client_key_cache |
180 | | - .set(&cont, salted_password, mac.clone()); |
181 | | - |
182 | | - (salted_password, mac) |
183 | | - } |
184 | | - }; |
| 179 | + let (salted_password, mut mac) = options |
| 180 | + .sasl_client_key_cache |
| 181 | + .get_or_compute(options.password.as_deref().unwrap_or_default(), &cont) |
| 182 | + .await?; |
185 | 183 |
|
186 | 184 | mac.update(b"Client Key"); |
187 | 185 |
|
|
0 commit comments