forked from launchbadge/sqlx
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdsql.rs
More file actions
170 lines (148 loc) · 5.89 KB
/
dsql.rs
File metadata and controls
170 lines (148 loc) · 5.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
//! Aurora DSQL requires an IAM token in place of a password. Tokens are
//! generated by the AWS SDK using your AWS credentials.
use std::{
borrow::Cow,
fmt,
sync::{Arc, RwLock},
time::Duration,
};
use aws_config::{BehaviorVersion, SdkConfig};
use aws_sdk_dsql::{
auth_token::{AuthToken, AuthTokenGenerator, Config},
error::BoxError,
};
use sqlx_postgres::PasswordProvider;
use tokio::{task::JoinHandle, time::sleep};
/// Builder for a [`DsqlIamProvider`] with non-default settings.
///
/// Reach for this when the AWS SDK defaults are not what you need, for example
/// to supply your own [`SdkConfig`] or a customized auth token generator
/// [`Config`]. When the defaults are fine, [`DsqlIamProvider::new`] is simpler.
///
/// ```ignore
/// use sqlx_aws::iam::dsql::*;
///
/// let b = DsqlIamProviderBuilder::defaults().await;
/// let my_config = Config::builder().hostname("...").build()?;
/// let provider = b.with_generator_config(my_config).await?;
/// ```
pub struct DsqlIamProviderBuilder {
    /// AWS SDK configuration handed to the token generator on every call.
    cfg: SdkConfig,
    /// When `true`, tokens come from `db_connect_admin_auth_token`; otherwise
    /// from `db_connect_auth_token`.
    is_admin: bool,
}
impl DsqlIamProviderBuilder {
/// A new builder. The AWS SDK is automatically configured.
pub async fn defaults() -> Self {
let cfg = aws_config::load_defaults(BehaviorVersion::latest()).await;
Self::new_with_sdk_cfg(cfg)
}
/// A new builder with custom SDK config.
pub fn new_with_sdk_cfg(cfg: SdkConfig) -> Self {
Self {
cfg,
is_admin: false,
}
}
/// Build a provider with the given [`auth_token::Config`].
pub async fn with_generator_config(self, config: Config) -> Result<DsqlIamProvider, BoxError> {
let DsqlIamProviderBuilder { cfg, is_admin } = self;
// This default value is hardcoded in the AuthTokenGenerator. There is
// no way to share the value.
let expires_in = config.expires_in().unwrap_or(900);
// Token generation is fast (because it is a local operation). However,
// there is some coordination involved (such as loading AWS credentials,
// or tokio scheduling). We want to avoid ever having stale tokens, and so schedule refreshes slightly ahead of expiry.
let refresh_interval = Duration::from_secs(if expires_in > 60 {
expires_in - 60
} else {
expires_in
});
let generator = AuthTokenGenerator::new(config);
// Boostrap: try once. This allows for failing fast for the case where
// things haven't been correctly configured.
let auth_token = match is_admin {
true => generator.db_connect_admin_auth_token(&cfg).await,
false => generator.db_connect_auth_token(&cfg).await,
}?;
let token = Arc::new(RwLock::new(Ok(auth_token)));
let _token = token.clone();
let task = tokio::spawn(async move {
sleep(refresh_interval).await;
loop {
let res = match is_admin {
true => generator.db_connect_admin_auth_token(&cfg).await,
false => generator.db_connect_auth_token(&cfg).await,
};
match res {
Ok(auth_token) => {
*_token.write().expect("never poisoned") = Ok(auth_token);
sleep(refresh_interval).await;
}
// XXX: In theory, this should almost never happen, because
// we did a boostrap token generation, which should catch
// nearly all errors. However, it is possible that the
// underlying credential provider has failed in some way.
Err(err) => {
// Refreshes are eager, which means it may be possible
// that we're about to replace perfectly good token with
// an error. It doesn't seem worthwhile to guard against
// that, since tokens are short lived and are likely to
// expire shortly anyways.
*_token.write().expect("never poisoned") = Err(err);
// sleep an arbitrary amount of time to prevent busy
// loops, but not so long that we don't try again (if
// the underlying error has been resolved).
sleep(Duration::from_secs(1)).await;
}
}
}
});
Ok(DsqlIamProvider { token, task })
}
}
/// A sqlx [`PasswordProvider`] that hands out Aurora DSQL IAM auth tokens and
/// keeps them refreshed in the background.
///
/// ```ignore
/// use sqlx_postgres::PgConnectOptions;
/// use sqlx_aws::iam::dsql::*;
///
/// let provider = DsqlIamProvider::new("peccy.dsql.us-east-1.on.aws").await?;
/// let opts = PgConnectOptions::new_without_pgpass()
///     .password(provider);
/// ```
pub struct DsqlIamProvider {
    /// Outcome of the most recent token generation. The background task
    /// writes it and [`PasswordProvider::password`] reads it.
    token: Arc<RwLock<Result<AuthToken, BoxError>>>,
    /// Handle to the background refresh task; aborted on drop.
    task: JoinHandle<()>,
}
impl Drop for DsqlIamProvider {
fn drop(&mut self) {
self.task.abort();
}
}
impl DsqlIamProvider {
    /// Creates a provider for the DSQL cluster at `hostname`, using the AWS SDK
    /// defaults for everything else.
    ///
    /// This is shorthand for [`DsqlIamProviderBuilder::defaults`] followed by
    /// [`DsqlIamProviderBuilder::with_generator_config`].
    ///
    /// # Errors
    ///
    /// Fails if the token generator [`Config`] cannot be built, or if the
    /// initial token generation fails.
    pub async fn new(hostname: impl Into<String>) -> Result<Self, BoxError> {
        // Build the generator config first. It is a cheap local step, so an
        // invalid config is reported before the slower SDK default loading.
        // Propagate instead of panicking: this is a fallible library API.
        let config = Config::builder().hostname(hostname).build()?;
        let builder = DsqlIamProviderBuilder::defaults().await;
        builder.with_generator_config(config).await
    }
}
impl PasswordProvider for DsqlIamProvider {
fn password<'a>(&'a self) -> Result<Cow<'a, str>, sqlx_core::error::BoxDynError> {
match &*self.token.read().expect("never poisoned") {
Ok(auth_token) => Ok(Cow::Owned(auth_token.as_str().to_string())),
Err(err) => Err(Box::new(RefreshError(format!("{err}")))),
}
}
}
/// Error reported by the DSQL password provider when the latest background
/// token refresh failed.
///
/// Holds the rendered message of the underlying error.
#[derive(Debug)]
pub struct RefreshError(String);

impl fmt::Display for RefreshError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("unable to refresh auth token: ")?;
        f.write_str(&self.0)
    }
}

impl std::error::Error for RefreshError {}