Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions aws_advanced_python_wrapper/sql_alchemy_connection_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ class SqlAlchemyPooledConnectionProvider(ConnectionProvider, CanReleaseResources
"weighted_random": WeightedRandomHostSelector(),
"highest_weight": HighestWeightHostSelector()}
_rds_utils: ClassVar[RdsUtils] = RdsUtils()
_database_pools: ClassVar[SlidingExpirationCache[PoolKey, QueuePool]] = SlidingExpirationCache(
should_dispose_func=lambda queue_pool: queue_pool.checkedout() == 0,
item_disposal_func=lambda queue_pool: queue_pool.dispose()
_database_pools: ClassVar[SlidingExpirationCache[PoolKey, Tuple[QueuePool, Properties]]] = SlidingExpirationCache(
should_dispose_func=lambda pool_pair: pool_pair[0].checkedout() == 0,
item_disposal_func=lambda pool_pair: pool_pair[0].dispose()
)

def __init__(
Expand Down Expand Up @@ -119,7 +119,8 @@ def _num_connections(self, host_info: HostInfo) -> int:
num_connections = 0
for pool_key, cache_item in SqlAlchemyPooledConnectionProvider._database_pools.items():
if pool_key.url == host_info.url:
num_connections += cache_item.item.checkedout()
queue_pool, _ = cache_item.item
num_connections += queue_pool.checkedout()
return num_connections

def connect(
Expand All @@ -129,15 +130,22 @@ def connect(
database_dialect: DatabaseDialect,
host_info: HostInfo,
props: Properties):
queue_pool: Optional[QueuePool] = SqlAlchemyPooledConnectionProvider._database_pools.compute_if_absent(
db_pool: Optional[Tuple[QueuePool, Properties]] = SqlAlchemyPooledConnectionProvider._database_pools.compute_if_absent(
PoolKey(host_info.url, self._get_extra_key(host_info, props)),
lambda _: self._create_pool(target_func, driver_dialect, database_dialect, host_info, props),
SqlAlchemyPooledConnectionProvider._POOL_EXPIRATION_CHECK_NS
)

if queue_pool is None:
if db_pool is None:
raise AwsWrapperError(Messages.get_formatted("SqlAlchemyPooledConnectionProvider.PoolNone", host_info.url))

queue_pool, creator_props = db_pool

# Update the password in the creator's captured properties so new pooled connections use the latest credentials
password = WrapperProperties.PASSWORD.get(props)
if password is not None:
creator_props[WrapperProperties.PASSWORD.name] = password

return queue_pool.connect()

# The pool key should always be retrieved using this method, because the username
Expand All @@ -163,7 +171,7 @@ def _create_pool(
prepared_properties = driver_dialect.prepare_connect_info(host_info, props)
database_dialect.prepare_conn_props(prepared_properties)
kwargs["creator"] = self._get_connection_func(target_func, prepared_properties)
return self._create_sql_alchemy_pool(**kwargs)
return self._create_sql_alchemy_pool(**kwargs), prepared_properties

def _get_connection_func(self, target_connect_func: Callable, props: Properties):
return lambda: target_connect_func(**props)
Expand All @@ -174,7 +182,8 @@ def _create_sql_alchemy_pool(self, **kwargs):
def release_resources(self):
for _, cache_item in SqlAlchemyPooledConnectionProvider._database_pools.items():
try:
cache_item.item.dispose()
queue_pool, _ = cache_item.item
queue_pool.dispose()
except Exception:
# Swallow exception, connections may already be dead
pass
Expand Down
5 changes: 4 additions & 1 deletion aws_advanced_python_wrapper/utils/pg_exception_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ def _is_network_error(self, error: Optional[BaseException], sql_state: Optional[
return False
# Check the error message if this is a generic error
error_msg: str = error.args[0]
return any(error_msg.startswith(msg) for msg in self._NETWORK_ERROR_MESSAGES)
is_network_error: bool = any(error_msg.startswith(msg) for msg in self._NETWORK_ERROR_MESSAGES)
# PAM errors may be nested in the connection error, double check here to avoid false positives
# Example nested error: 'connection failed: ...: FATAL: password authentication failed for user'
return is_network_error and not any(msg in error_msg for msg in self._ACCESS_ERROR_MESSAGES)

return False

Expand Down
66 changes: 66 additions & 0 deletions tests/unit/test_exception_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,72 @@ def test_is_login_exception_with_nested_non_login_error_pg(pg_handler):
assert pg_handler.is_login_exception(error=wrapper_error) is False


def test_nested_pam_error_in_connection_failed_is_not_network_exception_pg(pg_handler):
error_msg = (
'connection failed: connection to server at "", port 5432 failed: '
'FATAL: PAM authentication failed for user ""'
)
wrapper_error = AwsWrapperError("Error occurred while opening a connection", OperationalError(error_msg))

assert pg_handler.is_network_exception(error=wrapper_error) is False


def test_nested_password_auth_error_in_connection_failed_is_not_network_exception_pg(pg_handler):
error_msg = (
'connection failed: connection to server at "", port 5432 failed: '
'FATAL: password authentication failed for user ""'
)
wrapper_error = AwsWrapperError("Error occurred while opening a connection", OperationalError(error_msg))

assert pg_handler.is_network_exception(error=wrapper_error) is False


def test_nested_pam_error_deeply_wrapped_is_not_network_exception_pg(pg_handler):
error_msg = (
'connection failed: connection to server at "", port 5432 failed: '
'FATAL: password authentication failed for user ""'
)
wrapper_error = AwsWrapperError(
"[IamAuthPlugin] Error occurred while opening a connection",
AwsWrapperError("Inner wrapper", OperationalError(error_msg)))

assert pg_handler.is_network_exception(error=wrapper_error) is False


def test_nested_pam_error_is_login_exception_pg(pg_handler):
error_msg = (
'connection failed: connection to server at "", port 5432 failed: '
'FATAL: PAM authentication failed for user ""'
)
wrapper_error = AwsWrapperError("Error occurred while opening a connection", OperationalError(error_msg))

assert pg_handler.is_login_exception(error=wrapper_error) is True


def test_nested_password_auth_error_is_login_exception_pg(pg_handler):
error_msg = (
'connection failed: connection to server at "", port 5432 failed: '
'FATAL: password authentication failed for user ""'
)
wrapper_error = AwsWrapperError("Error occurred while opening a connection", OperationalError(error_msg))

assert pg_handler.is_login_exception(error=wrapper_error) is True


def test_pure_connection_failed_is_network_exception_pg(pg_handler):
error_msg = 'connection failed: connection to server at "", port 5432 failed: Connection refused'
wrapper_error = AwsWrapperError("Error occurred while opening a connection", OperationalError(error_msg))

assert pg_handler.is_network_exception(error=wrapper_error) is True


def test_pure_connection_failed_is_not_login_exception_pg(pg_handler):
error_msg = 'connection failed: connection to server at "", port 5432 failed: Connection refused'
wrapper_error = AwsWrapperError("Error occurred while opening a connection", OperationalError(error_msg))

assert pg_handler.is_login_exception(error=wrapper_error) is False


def test_is_read_only_exception_with_nested_aws_wrapper_error_mysql(mysql_handler):
class MockReadOnlyError(Exception):
def __init__(self):
Expand Down
98 changes: 91 additions & 7 deletions tests/unit/test_sql_alchemy_pooled_connection_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ def clear_cache():
SqlAlchemyPooledConnectionProvider._database_pools.clear()


@pytest.fixture
def mock_dialects(mocker):
mock_driver_dialect = mocker.MagicMock()
mock_database_dialect = mocker.MagicMock()
mock_driver_dialect.prepare_connect_info.side_effect = lambda host, props: Properties(props.copy())
mock_database_dialect.prepare_conn_props.return_value = None
return mock_driver_dialect, mock_database_dialect


def test_connect__default_mapping__default_pool_configuration(provider, host_info, mocker, mock_conn, mock_pool):
expected_urls = {host_info.url}
expected_keys = [PoolKey(host_info.url, "user1")]
Expand Down Expand Up @@ -100,6 +109,69 @@ def test_connect__custom_configuration_and_mapping(host_info, mocker, mock_conn,
mock_pool_initializer_func.assert_called_with(creator=mock_pool_connection_func, pool_size=10)


def test_connect__updates_password_in_cached_pool_creator_props(host_info, mocker, mock_dialects):
captured_password: list = []
props = Properties({WrapperProperties.USER.name: "user1", WrapperProperties.PASSWORD.name: "TOKEN_1"})

def fake_target_connect(**kwargs):
captured_password.append(kwargs.get("password"))
return mocker.MagicMock(spec=psycopg.Connection)

provider = SqlAlchemyPooledConnectionProvider(
pool_configurator=lambda _, __: {"pool_size": 0, "max_overflow": 2}
)
mock_driver_dialect, mock_database_dialect = mock_dialects

# Create a cached pool
conn_1 = provider.connect(fake_target_connect, mock_driver_dialect, mock_database_dialect, host_info, props)
assert conn_1 is not None

# Rotate password
props[WrapperProperties.PASSWORD.name] = "TOKEN_2"

conn_2 = provider.connect(fake_target_connect, mock_driver_dialect, mock_database_dialect, host_info, props)
assert conn_2 is not None

assert captured_password == ["TOKEN_1", "TOKEN_2"]


def test_connect__password_update_different_pool_keys(host_info, mocker, mock_dialects):
captured_password_user1: list = []
captured_password_user2: list = []
props_user1 = Properties({WrapperProperties.USER.name: "user1", WrapperProperties.PASSWORD.name: "TOKEN_1"})
props_user2 = Properties({WrapperProperties.USER.name: "user2", WrapperProperties.PASSWORD.name: "TOKEN_2"})

def fake_target_connect(**kwargs):
user = kwargs.get("user")
pwd = kwargs.get("password")
if user == "user1":
captured_password_user1.append(pwd)
else:
captured_password_user2.append(pwd)
return mocker.MagicMock(spec=psycopg.Connection)

provider = SqlAlchemyPooledConnectionProvider(
pool_configurator=lambda _, __: {"pool_size": 0, "max_overflow": 3}
)
mock_driver_dialect, mock_database_dialect = mock_dialects

conn1 = provider.connect(fake_target_connect, mock_driver_dialect, mock_database_dialect, host_info, props_user1)
conn2 = provider.connect(fake_target_connect, mock_driver_dialect, mock_database_dialect, host_info, props_user2)
assert conn1 is not None
assert conn2 is not None

assert captured_password_user1 == ["TOKEN_1"]
assert captured_password_user2 == ["TOKEN_2"]

# Rotate password for user 1
props_user1[WrapperProperties.PASSWORD.name] = "TOKEN_3"
conn3 = provider.connect(fake_target_connect, mock_driver_dialect, mock_database_dialect, host_info, props_user1)
assert conn3 is not None

assert captured_password_user1 == ["TOKEN_1", "TOKEN_3"]
assert captured_password_user2 == ["TOKEN_2"]


def test_accepts_host_info(provider):
instance_url = "instance-1.XYZ.us-east-2.rds.amazonaws.com"
instance_host_info = HostInfo(instance_url)
Expand All @@ -115,18 +187,24 @@ def test_least_connections_strategy(provider, mock_pool):
writer = HostInfo("writer.XYZ.us-east-2.rds.amazonaws.com")
reader_1 = HostInfo("reader-1.XYZ.us-east-2.rds.amazonaws.com", role=HostRole.READER)
reader_2 = HostInfo("reader-2.XYZ.us-east-2.rds.amazonaws.com", role=HostRole.READER)
hosts = [writer, reader_1, reader_2]
hosts = (writer, reader_1, reader_2)
props = Properties({WrapperProperties.USER.name: "user1", WrapperProperties.PASSWORD.name: "password"})

# Create cache with 1 pool to reader_url_1_connection and 2 pools to reader_url_2_connections.
# Each pool holds 1 connection.
test_database_pools = SlidingExpirationCache()
test_database_pools.compute_if_absent(
PoolKey(reader_1.url, "user1"), lambda _: mock_pool, 10 * 60_000_000_000)
PoolKey(reader_1.url, "user1"),
lambda _: (mock_pool, Properties()),
10 * 60_000_000_000)
test_database_pools.compute_if_absent(
PoolKey(reader_2.url, "user1"), lambda _: mock_pool, 10 * 60_000_000_000)
PoolKey(reader_2.url, "user1"),
lambda _: (mock_pool, Properties()),
10 * 60_000_000_000)
test_database_pools.compute_if_absent(
PoolKey(reader_2.url, "user2"), lambda _: mock_pool, 10 * 60_000_000_000)
PoolKey(reader_2.url, "user2"),
lambda _: (mock_pool, Properties()),
10 * 60_000_000_000)

result = provider.get_host_info_by_strategy(hosts, HostRole.READER, "least_connections", props)
assert reader_1 == result
Expand All @@ -135,15 +213,21 @@ def test_least_connections_strategy(provider, mock_pool):
def test_least_connections_strategy__no_hosts_matching_role(provider):
props = Properties()
with pytest.raises(AwsWrapperError):
provider.get_host_info_by_strategy([HostInfo("writer")], HostRole.READER, "least_connections", props)
provider.get_host_info_by_strategy((HostInfo("writer"),), HostRole.READER, "least_connections", props)


def test_release_resources(provider, mocker):
pool1 = mocker.MagicMock()
pool2 = mocker.MagicMock()
test_database_pools = SlidingExpirationCache()
test_database_pools.compute_if_absent(PoolKey("url1", "user1"), lambda _: pool1, 60_000_000_000)
test_database_pools.compute_if_absent(PoolKey("url1", "user2"), lambda _: pool2, 60_000_000_000)
test_database_pools.compute_if_absent(
PoolKey("url1", "user1"),
lambda _: (pool1, Properties()),
60_000_000_000)
test_database_pools.compute_if_absent(
PoolKey("url1", "user2"),
lambda _: (pool2, Properties()),
60_000_000_000)
SqlAlchemyPooledConnectionProvider._database_pools = test_database_pools

provider.release_resources()
Expand Down
Loading