Skip to content

Commit 7b46dd6

Browse files
committed
Use async RwLock for SASL client key cache
Prevents redundant concurrent hi() computations during pool startup.
1 parent 41b8688 commit 7b46dd6

3 files changed

Lines changed: 52 additions & 52 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sqlx-postgres/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ time = { workspace = true, optional = true }
5252
uuid = { workspace = true, optional = true }
5353

5454
# Misc
55+
async-lock = "3.4"
5556
atoi = "2.0"
5657
base64 = { version = "0.22.0", default-features = false, features = ["std"] }
5758
bitflags = { version = "2", default-features = false }

sqlx-postgres/src/connection/sasl.rs

Lines changed: 50 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
use std::sync::{Arc, Mutex};
1+
use std::sync::Arc;
2+
3+
use async_lock::{RwLock, RwLockUpgradableReadGuard};
24

35
use crate::connection::stream::PgStream;
46
use crate::error::Error;
@@ -30,63 +32,77 @@ const NONCE_ATTR: &str = "r";
3032
/// that affect the HMAC result. The password is not included in the cache key because it can only
3133
/// change via `&mut self` on `PgConnectOptions`, which replaces the cache instance.
3234
///
35+
/// An async `RwLock` is used so that only one caller computes the key at a time; subsequent callers
36+
/// wait and then read the cached result.
37+
///
3338
/// According to [RFC-7677](https://datatracker.ietf.org/doc/html/rfc7677):
3439
///
3540
/// > This computational cost can be avoided by caching the ClientKey (assuming the Salt and hash
3641
/// > iteration-count is stable).
3742
#[derive(Debug, Clone)]
3843
pub struct ClientKeyCache {
39-
inner: Arc<Mutex<Option<CacheInner>>>,
44+
inner: Arc<RwLock<Option<CacheEntry>>>,
4045
}
4146

4247
#[derive(Debug)]
43-
struct CacheInner {
48+
struct CacheEntry {
49+
// Keys
4450
salt: Vec<u8>,
4551
iterations: u32,
52+
53+
// Values
4654
salted_password: [u8; 32],
4755
client_key: Hmac<Sha256>,
4856
}
4957

58+
impl CacheEntry {
59+
fn matches(&self, cont: &AuthenticationSaslContinue) -> bool {
60+
self.salt == cont.salt && self.iterations == cont.iterations
61+
}
62+
}
63+
5064
impl ClientKeyCache {
5165
pub fn new() -> Self {
5266
ClientKeyCache {
53-
inner: Arc::new(Mutex::new(None)),
67+
inner: Arc::new(RwLock::new(None)),
5468
}
5569
}
5670

57-
fn get(
71+
/// Returns the cached salted password and client key HMAC if the cache matches the given
72+
/// salt and iteration count. Otherwise, computes and caches them.
73+
async fn get_or_compute(
5874
&self,
75+
password: &str,
5976
cont: &AuthenticationSaslContinue,
60-
) -> Option<([u8; 32], Hmac<Sha256>)> {
61-
self.inner
62-
.lock()
63-
.expect("BUG: panicked while holding a lock")
64-
.as_ref()
65-
.and_then(|inner| {
66-
if inner.salt == cont.salt && inner.iterations == cont.iterations {
67-
Some((inner.salted_password, inner.client_key.clone()))
68-
} else {
69-
None
70-
}
71-
})
72-
}
77+
) -> Result<([u8; 32], Hmac<Sha256>), Error> {
78+
let guard = self.inner.upgradable_read().await;
7379

74-
fn set(
75-
&self,
76-
cont: &AuthenticationSaslContinue,
77-
salted_password: [u8; 32],
78-
client_key: Hmac<Sha256>,
79-
) {
80-
let mut inner = self
81-
.inner
82-
.lock()
83-
.expect("BUG: panicked while holding a lock");
84-
*inner = Some(CacheInner {
80+
if let Some(entry) = guard.as_ref().filter(|e| e.matches(cont)) {
81+
return Ok((entry.salted_password, entry.client_key.clone()));
82+
}
83+
84+
let mut guard = RwLockUpgradableReadGuard::upgrade(guard).await;
85+
86+
// Re-check after acquiring the write lock, in case another caller populated the cache.
87+
if let Some(entry) = guard.as_ref().filter(|e| e.matches(cont)) {
88+
return Ok((entry.salted_password, entry.client_key.clone()));
89+
}
90+
91+
// SaltedPassword := Hi(Normalize(password), salt, i)
92+
let salted_password = hi(password, &cont.salt, cont.iterations).await?;
93+
94+
// ClientKey := HMAC(SaltedPassword, "Client Key")
95+
let client_key =
96+
Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?;
97+
98+
*guard = Some(CacheEntry {
8599
salt: cont.salt.clone(),
86100
iterations: cont.iterations,
87101
salted_password,
88-
client_key,
102+
client_key: client_key.clone(),
89103
});
104+
105+
Ok((salted_password, client_key))
90106
}
91107
}
92108

@@ -160,28 +176,10 @@ pub(crate) async fn authenticate(
160176
}
161177
};
162178

163-
let (salted_password, mut mac) = {
164-
if let Some(cached) = options.sasl_client_key_cache.get(&cont) {
165-
cached
166-
} else {
167-
// SaltedPassword := Hi(Normalize(password), salt, i)
168-
let salted_password = hi(
169-
options.password.as_deref().unwrap_or_default(),
170-
&cont.salt,
171-
cont.iterations,
172-
)
173-
.await?;
174-
175-
// ClientKey := HMAC(SaltedPassword, "Client Key")
176-
let mac = Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?;
177-
178-
options
179-
.sasl_client_key_cache
180-
.set(&cont, salted_password, mac.clone());
181-
182-
(salted_password, mac)
183-
}
184-
};
179+
let (salted_password, mut mac) = options
180+
.sasl_client_key_cache
181+
.get_or_compute(options.password.as_deref().unwrap_or_default(), &cont)
182+
.await?;
185183

186184
mac.update(b"Client Key");
187185

0 commit comments

Comments
 (0)