Skip to content

Commit b676d11

Browse files
committed
fix: new pooled connections created with stale password
1 parent 305e552 commit b676d11

4 files changed

Lines changed: 178 additions & 16 deletions

File tree

aws_advanced_python_wrapper/sql_alchemy_connection_provider.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ class SqlAlchemyPooledConnectionProvider(ConnectionProvider, CanReleaseResources
5151
"weighted_random": WeightedRandomHostSelector(),
5252
"highest_weight": HighestWeightHostSelector()}
5353
_rds_utils: ClassVar[RdsUtils] = RdsUtils()
54-
_database_pools: ClassVar[SlidingExpirationCache[PoolKey, QueuePool]] = SlidingExpirationCache(
55-
should_dispose_func=lambda queue_pool: queue_pool.checkedout() == 0,
56-
item_disposal_func=lambda queue_pool: queue_pool.dispose()
54+
_database_pools: ClassVar[SlidingExpirationCache[PoolKey, Tuple[QueuePool, Properties]]] = SlidingExpirationCache(
55+
should_dispose_func=lambda pool_pair: pool_pair[0].checkedout() == 0,
56+
item_disposal_func=lambda pool_pair: pool_pair[0].dispose()
5757
)
5858

5959
def __init__(
@@ -119,7 +119,8 @@ def _num_connections(self, host_info: HostInfo) -> int:
119119
num_connections = 0
120120
for pool_key, cache_item in SqlAlchemyPooledConnectionProvider._database_pools.items():
121121
if pool_key.url == host_info.url:
122-
num_connections += cache_item.item.checkedout()
122+
queue_pool, _ = cache_item.item
123+
num_connections += queue_pool.checkedout()
123124
return num_connections
124125

125126
def connect(
@@ -129,15 +130,22 @@ def connect(
129130
database_dialect: DatabaseDialect,
130131
host_info: HostInfo,
131132
props: Properties):
132-
queue_pool: Optional[QueuePool] = SqlAlchemyPooledConnectionProvider._database_pools.compute_if_absent(
133+
db_pool: Optional[Tuple[QueuePool, Properties]] = SqlAlchemyPooledConnectionProvider._database_pools.compute_if_absent(
133134
PoolKey(host_info.url, self._get_extra_key(host_info, props)),
134135
lambda _: self._create_pool(target_func, driver_dialect, database_dialect, host_info, props),
135136
SqlAlchemyPooledConnectionProvider._POOL_EXPIRATION_CHECK_NS
136137
)
137138

138-
if queue_pool is None:
139+
if db_pool is None:
139140
raise AwsWrapperError(Messages.get_formatted("SqlAlchemyPooledConnectionProvider.PoolNone", host_info.url))
140141

142+
queue_pool, creator_props = db_pool
143+
144+
# Update the password in the creator's captured properties so new pooled connections use the latest credentials
145+
password = WrapperProperties.PASSWORD.get(props)
146+
if password is not None:
147+
creator_props[WrapperProperties.PASSWORD.name] = password
148+
141149
return queue_pool.connect()
142150

143151
# The pool key should always be retrieved using this method, because the username
@@ -163,7 +171,7 @@ def _create_pool(
163171
prepared_properties = driver_dialect.prepare_connect_info(host_info, props)
164172
database_dialect.prepare_conn_props(prepared_properties)
165173
kwargs["creator"] = self._get_connection_func(target_func, prepared_properties)
166-
return self._create_sql_alchemy_pool(**kwargs)
174+
return self._create_sql_alchemy_pool(**kwargs), prepared_properties
167175

168176
def _get_connection_func(self, target_connect_func: Callable, props: Properties):
169177
return lambda: target_connect_func(**props)
@@ -174,7 +182,8 @@ def _create_sql_alchemy_pool(self, **kwargs):
174182
def release_resources(self):
175183
for _, cache_item in SqlAlchemyPooledConnectionProvider._database_pools.items():
176184
try:
177-
cache_item.item.dispose()
185+
queue_pool, _ = cache_item.item
186+
queue_pool.dispose()
178187
except Exception:
179188
# Swallow exception, connections may already be dead
180189
pass

aws_advanced_python_wrapper/utils/pg_exception_handler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,10 @@ def _is_network_error(self, error: Optional[BaseException], sql_state: Optional[
7575
return False
7676
# Check the error message if this is a generic error
7777
error_msg: str = error.args[0]
78-
return any(error_msg.startswith(msg) for msg in self._NETWORK_ERROR_MESSAGES)
78+
is_network_error: bool = any(error_msg.startswith(msg) for msg in self._NETWORK_ERROR_MESSAGES)
79+
# PAM errors may be nested in the connection error, double check here to avoid false positives
80+
# Example nested error: 'connection failed: ...: FATAL: password authentication failed for user'
81+
return is_network_error and not any(msg in error_msg for msg in self._ACCESS_ERROR_MESSAGES)
7982

8083
return False
8184

tests/unit/test_exception_handling.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,72 @@ def test_is_login_exception_with_nested_non_login_error_pg(pg_handler):
162162
assert pg_handler.is_login_exception(error=wrapper_error) is False
163163

164164

165+
def test_nested_pam_error_in_connection_failed_is_not_network_exception_pg(pg_handler):
166+
error_msg = (
167+
'connection failed: connection to server at "", port 5432 failed: '
168+
'FATAL: PAM authentication failed for user ""'
169+
)
170+
wrapper_error = AwsWrapperError("Error occurred while opening a connection", OperationalError(error_msg))
171+
172+
assert pg_handler.is_network_exception(error=wrapper_error) is False
173+
174+
175+
def test_nested_password_auth_error_in_connection_failed_is_not_network_exception_pg(pg_handler):
176+
error_msg = (
177+
'connection failed: connection to server at "", port 5432 failed: '
178+
'FATAL: password authentication failed for user ""'
179+
)
180+
wrapper_error = AwsWrapperError("Error occurred while opening a connection", OperationalError(error_msg))
181+
182+
assert pg_handler.is_network_exception(error=wrapper_error) is False
183+
184+
185+
def test_nested_pam_error_deeply_wrapped_is_not_network_exception_pg(pg_handler):
186+
error_msg = (
187+
'connection failed: connection to server at "", port 5432 failed: '
188+
'FATAL: password authentication failed for user ""'
189+
)
190+
wrapper_error = AwsWrapperError(
191+
"[IamAuthPlugin] Error occurred while opening a connection",
192+
AwsWrapperError("Inner wrapper", OperationalError(error_msg)))
193+
194+
assert pg_handler.is_network_exception(error=wrapper_error) is False
195+
196+
197+
def test_nested_pam_error_is_login_exception_pg(pg_handler):
198+
error_msg = (
199+
'connection failed: connection to server at "", port 5432 failed: '
200+
'FATAL: PAM authentication failed for user ""'
201+
)
202+
wrapper_error = AwsWrapperError("Error occurred while opening a connection", OperationalError(error_msg))
203+
204+
assert pg_handler.is_login_exception(error=wrapper_error) is True
205+
206+
207+
def test_nested_password_auth_error_is_login_exception_pg(pg_handler):
208+
error_msg = (
209+
'connection failed: connection to server at "", port 5432 failed: '
210+
'FATAL: password authentication failed for user ""'
211+
)
212+
wrapper_error = AwsWrapperError("Error occurred while opening a connection", OperationalError(error_msg))
213+
214+
assert pg_handler.is_login_exception(error=wrapper_error) is True
215+
216+
217+
def test_pure_connection_failed_is_network_exception_pg(pg_handler):
218+
error_msg = 'connection failed: connection to server at "", port 5432 failed: Connection refused'
219+
wrapper_error = AwsWrapperError("Error occurred while opening a connection", OperationalError(error_msg))
220+
221+
assert pg_handler.is_network_exception(error=wrapper_error) is True
222+
223+
224+
def test_pure_connection_failed_is_not_login_exception_pg(pg_handler):
225+
error_msg = 'connection failed: connection to server at "", port 5432 failed: Connection refused'
226+
wrapper_error = AwsWrapperError("Error occurred while opening a connection", OperationalError(error_msg))
227+
228+
assert pg_handler.is_login_exception(error=wrapper_error) is False
229+
230+
165231
def test_is_read_only_exception_with_nested_aws_wrapper_error_mysql(mysql_handler):
166232
class MockReadOnlyError(Exception):
167233
def __init__(self):

tests/unit/test_sql_alchemy_pooled_connection_provider.py

Lines changed: 91 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,15 @@ def clear_cache():
6363
SqlAlchemyPooledConnectionProvider._database_pools.clear()
6464

6565

66+
@pytest.fixture
67+
def mock_dialects(mocker):
68+
mock_driver_dialect = mocker.MagicMock()
69+
mock_database_dialect = mocker.MagicMock()
70+
mock_driver_dialect.prepare_connect_info.side_effect = lambda host, props: Properties(props.copy())
71+
mock_database_dialect.prepare_conn_props.return_value = None
72+
return mock_driver_dialect, mock_database_dialect
73+
74+
6675
def test_connect__default_mapping__default_pool_configuration(provider, host_info, mocker, mock_conn, mock_pool):
6776
expected_urls = {host_info.url}
6877
expected_keys = [PoolKey(host_info.url, "user1")]
@@ -100,6 +109,69 @@ def test_connect__custom_configuration_and_mapping(host_info, mocker, mock_conn,
100109
mock_pool_initializer_func.assert_called_with(creator=mock_pool_connection_func, pool_size=10)
101110

102111

112+
def test_connect__updates_password_in_cached_pool_creator_props(host_info, mocker, mock_dialects):
113+
captured_password: list = []
114+
props = Properties({WrapperProperties.USER.name: "user1", WrapperProperties.PASSWORD.name: "TOKEN_1"})
115+
116+
def fake_target_connect(**kwargs):
117+
captured_password.append(kwargs.get("password"))
118+
return mocker.MagicMock(spec=psycopg.Connection)
119+
120+
provider = SqlAlchemyPooledConnectionProvider(
121+
pool_configurator=lambda _, __: {"pool_size": 0, "max_overflow": 2}
122+
)
123+
mock_driver_dialect, mock_database_dialect = mock_dialects
124+
125+
# Create a cached pool
126+
conn_1 = provider.connect(fake_target_connect, mock_driver_dialect, mock_database_dialect, host_info, props)
127+
assert conn_1 is not None
128+
129+
# Rotate password
130+
props[WrapperProperties.PASSWORD.name] = "TOKEN_2"
131+
132+
conn_2 = provider.connect(fake_target_connect, mock_driver_dialect, mock_database_dialect, host_info, props)
133+
assert conn_2 is not None
134+
135+
assert captured_password == ["TOKEN_1", "TOKEN_2"]
136+
137+
138+
def test_connect__password_update_different_pool_keys(host_info, mocker, mock_dialects):
139+
captured_password_user1: list = []
140+
captured_password_user2: list = []
141+
props_user1 = Properties({WrapperProperties.USER.name: "user1", WrapperProperties.PASSWORD.name: "TOKEN_1"})
142+
props_user2 = Properties({WrapperProperties.USER.name: "user2", WrapperProperties.PASSWORD.name: "TOKEN_2"})
143+
144+
def fake_target_connect(**kwargs):
145+
user = kwargs.get("user")
146+
pwd = kwargs.get("password")
147+
if user == "user1":
148+
captured_password_user1.append(pwd)
149+
else:
150+
captured_password_user2.append(pwd)
151+
return mocker.MagicMock(spec=psycopg.Connection)
152+
153+
provider = SqlAlchemyPooledConnectionProvider(
154+
pool_configurator=lambda _, __: {"pool_size": 0, "max_overflow": 3}
155+
)
156+
mock_driver_dialect, mock_database_dialect = mock_dialects
157+
158+
conn1 = provider.connect(fake_target_connect, mock_driver_dialect, mock_database_dialect, host_info, props_user1)
159+
conn2 = provider.connect(fake_target_connect, mock_driver_dialect, mock_database_dialect, host_info, props_user2)
160+
assert conn1 is not None
161+
assert conn2 is not None
162+
163+
assert captured_password_user1 == ["TOKEN_1"]
164+
assert captured_password_user2 == ["TOKEN_2"]
165+
166+
# Rotate password for user 1
167+
props_user1[WrapperProperties.PASSWORD.name] = "TOKEN_3"
168+
conn3 = provider.connect(fake_target_connect, mock_driver_dialect, mock_database_dialect, host_info, props_user1)
169+
assert conn3 is not None
170+
171+
assert captured_password_user1 == ["TOKEN_1", "TOKEN_3"]
172+
assert captured_password_user2 == ["TOKEN_2"]
173+
174+
103175
def test_accepts_host_info(provider):
104176
instance_url = "instance-1.XYZ.us-east-2.rds.amazonaws.com"
105177
instance_host_info = HostInfo(instance_url)
@@ -115,18 +187,24 @@ def test_least_connections_strategy(provider, mock_pool):
115187
writer = HostInfo("writer.XYZ.us-east-2.rds.amazonaws.com")
116188
reader_1 = HostInfo("reader-1.XYZ.us-east-2.rds.amazonaws.com", role=HostRole.READER)
117189
reader_2 = HostInfo("reader-2.XYZ.us-east-2.rds.amazonaws.com", role=HostRole.READER)
118-
hosts = [writer, reader_1, reader_2]
190+
hosts = (writer, reader_1, reader_2)
119191
props = Properties({WrapperProperties.USER.name: "user1", WrapperProperties.PASSWORD.name: "password"})
120192

121193
# Create cache with 1 pool to reader_url_1_connection and 2 pools to reader_url_2_connections.
122194
# Each pool holds 1 connection.
123195
test_database_pools = SlidingExpirationCache()
124196
test_database_pools.compute_if_absent(
125-
PoolKey(reader_1.url, "user1"), lambda _: mock_pool, 10 * 60_000_000_000)
197+
PoolKey(reader_1.url, "user1"),
198+
lambda _: (mock_pool, Properties()),
199+
10 * 60_000_000_000)
126200
test_database_pools.compute_if_absent(
127-
PoolKey(reader_2.url, "user1"), lambda _: mock_pool, 10 * 60_000_000_000)
201+
PoolKey(reader_2.url, "user1"),
202+
lambda _: (mock_pool, Properties()),
203+
10 * 60_000_000_000)
128204
test_database_pools.compute_if_absent(
129-
PoolKey(reader_2.url, "user2"), lambda _: mock_pool, 10 * 60_000_000_000)
205+
PoolKey(reader_2.url, "user2"),
206+
lambda _: (mock_pool, Properties()),
207+
10 * 60_000_000_000)
130208

131209
result = provider.get_host_info_by_strategy(hosts, HostRole.READER, "least_connections", props)
132210
assert reader_1 == result
@@ -135,15 +213,21 @@ def test_least_connections_strategy(provider, mock_pool):
135213
def test_least_connections_strategy__no_hosts_matching_role(provider):
136214
props = Properties()
137215
with pytest.raises(AwsWrapperError):
138-
provider.get_host_info_by_strategy([HostInfo("writer")], HostRole.READER, "least_connections", props)
216+
provider.get_host_info_by_strategy((HostInfo("writer"),), HostRole.READER, "least_connections", props)
139217

140218

141219
def test_release_resources(provider, mocker):
142220
pool1 = mocker.MagicMock()
143221
pool2 = mocker.MagicMock()
144222
test_database_pools = SlidingExpirationCache()
145-
test_database_pools.compute_if_absent(PoolKey("url1", "user1"), lambda _: pool1, 60_000_000_000)
146-
test_database_pools.compute_if_absent(PoolKey("url1", "user2"), lambda _: pool2, 60_000_000_000)
223+
test_database_pools.compute_if_absent(
224+
PoolKey("url1", "user1"),
225+
lambda _: (pool1, Properties()),
226+
60_000_000_000)
227+
test_database_pools.compute_if_absent(
228+
PoolKey("url1", "user2"),
229+
lambda _: (pool2, Properties()),
230+
60_000_000_000)
147231
SqlAlchemyPooledConnectionProvider._database_pools = test_database_pools
148232

149233
provider.release_resources()

0 commit comments

Comments
 (0)