Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aws_advanced_python_wrapper/pep249_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class DbApiMethod(Enum):
CURSOR_NEXT = (30, "Cursor.__next__", False)
CURSOR_LASTROWID = (31, "Cursor.lastrowid", False)

# AWS Advaced Python Wrapper Methods for
# AWS Advanced Python Wrapper Methods for the execution pipelines.
CONNECT = (32, "connect", True)
FORCE_CONNECT = (33, "force_connect", True)
INIT_HOST_PROVIDER = (34, "init_host_provider", True)
Expand Down
62 changes: 38 additions & 24 deletions aws_advanced_python_wrapper/read_write_splitting_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)

from aws_advanced_python_wrapper.errors import (AwsWrapperError, FailoverError,
FailoverFailedError,
ReadWriteSplittingError)
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
from aws_advanced_python_wrapper.pep249_methods import DbApiMethod
Expand All @@ -49,6 +50,18 @@ class ReadWriteSplittingConnectionManager(Plugin):
DbApiMethod.CONNECT.method_name,
DbApiMethod.NOTIFY_CONNECTION_CHANGED.method_name,
DbApiMethod.CONNECTION_SET_READ_ONLY.method_name,

DbApiMethod.CONNECTION_COMMIT.method_name,
DbApiMethod.CONNECTION_AUTOCOMMIT.method_name,
DbApiMethod.CONNECTION_AUTOCOMMIT_SETTER.method_name,
DbApiMethod.CONNECTION_IS_READ_ONLY.method_name,
DbApiMethod.CONNECTION_SET_READ_ONLY.method_name,
DbApiMethod.CONNECTION_ROLLBACK.method_name,

DbApiMethod.CURSOR_EXECUTE.method_name,
DbApiMethod.CURSOR_FETCHONE.method_name,
DbApiMethod.CURSOR_FETCHMANY.method_name,
DbApiMethod.CURSOR_FETCHALL.method_name
}
_POOL_PROVIDER_CLASS_NAME = "aws_advanced_python_wrapper.sql_alchemy_connection_provider.SqlAlchemyPooledConnectionProvider"

Expand Down Expand Up @@ -137,12 +150,15 @@ def execute(
try:
return execute_func()
except Exception as ex:
if isinstance(ex, FailoverFailedError):
# Evict the current connection from the pool right away since it is not reusable
self._close_connections(False)
if isinstance(ex, FailoverError):
logger.debug(
"ReadWriteSplittingPlugin.FailoverExceptionWhileExecutingCommand",
method_name
)
self._close_idle_connections()
self._close_connections(True)
else:
logger.debug(
"ReadWriteSplittingPlugin.ExceptionWhileExecutingCommand",
Expand Down Expand Up @@ -200,6 +216,7 @@ def _initialize_writer_connection(self):
)
self._set_writer_connection(conn, writer_host)
self._switch_current_connection_to(conn, writer_host)
return None

def _switch_connection_if_required(self, read_only: bool):
current_conn = self._plugin_service.current_connection
Expand Down Expand Up @@ -293,7 +310,7 @@ def _switch_to_writer_connection(self):
"ReadWriteSplittingPlugin.NoWriterFound")

if self._is_reader_conn_from_internal_pool:
self._close_connection_if_idle(self._reader_connection)
self._close_connection(self._reader_connection)

logger.debug(
"ReadWriteSplittingPlugin.SwitchedFromReaderToWriter",
Expand All @@ -319,7 +336,7 @@ def _switch_to_reader_connection(self):
)
):
# The old reader cannot be used anymore, close it.
self._close_connection_if_idle(self._reader_connection)
self._close_connection(self._reader_connection)

self._in_read_write_split = True
if not self._is_connection_usable(self._reader_connection, driver_dialect):
Expand All @@ -345,7 +362,7 @@ def _switch_to_reader_connection(self):
self._initialize_reader_connection()

if self._is_writer_conn_from_internal_pool:
self._close_connection_if_idle(self._writer_connection)
self._close_connection(self._writer_connection)

def _initialize_reader_connection(self):
if self._connection_handler.has_no_readers():
Expand Down Expand Up @@ -383,38 +400,35 @@ def _initialize_reader_connection(self):
"ReadWriteSplittingPlugin.SwitchedFromWriterToReader", reader_host.url
)

def _close_connection_if_idle(self, internal_conn: Optional[Connection]):
def _close_connection(self, internal_conn: Optional[Connection], close_only_if_idle: bool = True):
if internal_conn is None:
return

current_conn = self._plugin_service.current_connection
driver_dialect = self._plugin_service.driver_dialect

if close_only_if_idle and internal_conn == current_conn:
# Connection is in use, do not close
return

try:
if internal_conn != current_conn and self._is_connection_usable(
internal_conn, driver_dialect
):
if self._is_connection_usable(internal_conn, driver_dialect):
driver_dialect.execute(DbApiMethod.CONNECTION_CLOSE.method_name, lambda: internal_conn.close())
if internal_conn == self._writer_connection:
self._writer_connection = None
self._writer_host_info = None
if internal_conn == self._reader_connection:
self._reader_connection = None
self._reader_host_info = None
except Exception:
# Ignore exceptions during cleanup - connection might already be dead
pass
finally:
if internal_conn == self._writer_connection:
self._writer_connection = None
self._writer_host_info = None
if internal_conn == self._reader_connection:
self._reader_connection = None
self._reader_host_info = None

def _close_idle_connections(self):
def _close_connections(self, close_only_if_idle: bool = True):
logger.debug("ReadWriteSplittingPlugin.ClosingInternalConnections")
self._close_connection_if_idle(self._reader_connection)
self._close_connection_if_idle(self._writer_connection)

# Always clear cached references even if connections couldn't be closed
self._reader_connection = None
self._reader_host_info = None
self._writer_connection = None
self._writer_host_info = None
self._close_connection(self._reader_connection, close_only_if_idle)
self._close_connection(self._writer_connection, close_only_if_idle)

@staticmethod
def log_and_raise_exception(log_msg: str):
Expand Down Expand Up @@ -450,7 +464,7 @@ def host_list_provider_service(self) -> Optional[HostListProviderService]:
...

@host_list_provider_service.setter
def host_list_provider_service(self, new_value: int) -> None:
def host_list_provider_service(self, new_value: HostListProviderService) -> None:
"""The setter for the 'host_list_provider_service' attribute."""
...

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ UnknownDialect.AbortConnection=[UnknownDialect] abort_connection was called, but
Wrapper.ConnectMethod=[Wrapper] Target driver should be a target driver's connect() method/function.
Wrapper.RequiredTargetDriver=[Wrapper] Target driver is required.
Wrapper.UnsupportedAttribute=[Wrapper] Target driver does not have the attribute: '{}'
Wrapper.Properties=[Wrapper] "Connection Properties: "
Wrapper.Properties=[Wrapper] "Connection Properties: {}"

WriterFailoverHandler.AlreadyWriter=[WriterFailoverHandler] Current reader connection is actually a new writer connection.
WriterFailoverHandler.CurrentTopologyNone=[WriterFailoverHandler] Current topology cannot be None.
Expand Down
7 changes: 4 additions & 3 deletions aws_advanced_python_wrapper/utils/telemetry/xray_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def set_success(self, success: bool):

def set_attribute(self, key: str, value: AttributeValue):
if self._trace_entity is not None:
self._trace_entity.put_annotation(key, value)
# XRay only supports str, bool, int, float - not sequences
if isinstance(value, (str, bool, int, float)):
self._trace_entity.put_annotation(key, value)

def set_exception(self, exception: Exception):
if self._trace_entity is not None and exception is not None:
Expand Down Expand Up @@ -90,8 +92,7 @@ def _clone_and_close_context(context: XRayTelemetryContext, trace_level: Telemet

clone._trace_entity.start_time = context._trace_entity.start_time

for key in context._trace_entity.annotations.items():
value = context._trace_entity.annotations[key]
for key, value in context._trace_entity.annotations.items():
if key != TelemetryConst.TRACE_NAME_ANNOTATION and value is not None:
clone.set_attribute(key, value)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/PGXRayTelemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
print("-- running application")
logging.basicConfig(level=logging.DEBUG)

xray_recorder.configure(sampler=LocalSampler({"version": 1, "default": {"fixed_target": 1, "rate": 1.0}}))
xray_recorder.configure(sampler=LocalSampler({"version": 1, "default": {"fixed_target": 1, "rate": 1.0}, "rules": []}))
global_sdk_config.set_sdk_enabled(True)

with xray_recorder.in_segment("python_xray_telemetry_app") as segment:
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/container/test_aurora_failover.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def test_writer_fail_within_transaction_start_transaction(

@pytest.mark.parametrize("plugins", ["aurora_connection_tracker,failover", "aurora_connection_tracker,failover_v2"])
@enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED])
@pytest.mark.repeat(5)
@pytest.mark.repeat(5)  # Run this test case a few more times since it is a flaky test
def test_writer_failover_in_idle_connections(
self, test_driver: TestDriver, props, conn_utils, aurora_utility, plugins):
target_driver_connect = DriverHelper.get_connect_func(test_driver)
Expand Down
47 changes: 30 additions & 17 deletions tests/integration/container/test_custom_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def rds_utils(self):
return RdsTestUtility(region)

@pytest.fixture(scope='class')
def props(self):
def default_props(self):
p: Properties = Properties(
{"plugins": "custom_endpoint,read_write_splitting,failover", "connect_timeout": 10_000, "autocommit": True, "cluster_id": "cluster1"})
{"connect_timeout": 10_000, "autocommit": True, "cluster_id": "cluster1"})

features = TestEnvironment.get_current().get_features()
if TestEnvironmentFeatures.TELEMETRY_TRACES_ENABLED in features \
Expand All @@ -77,6 +77,18 @@ def props(self):

return p

@pytest.fixture(scope='class')
def props_with_failover(self, default_props):
p = default_props.copy()
p["plugins"] = "custom_endpoint,read_write_splitting,failover"
return p

@pytest.fixture(scope='class')
def props(self, default_props):
p = default_props.copy()
p["plugins"] = "custom_endpoint,read_write_splitting"
return p

@pytest.fixture(scope='class', autouse=True)
def setup_and_teardown(self):
env_info = TestEnvironment.get_current().get_info()
Expand Down Expand Up @@ -193,7 +205,7 @@ def _wait_until_endpoint_deleted(self, rds_client):
else:
self.logger.debug(f"Custom endpoint '{self.endpoint_id}' successfully deleted.")

def wait_until_endpoint_has_members(self, rds_client, expected_members: Set[str]):
def wait_until_endpoint_has_members(self, rds_client, expected_members: Set[str], rds_utils):
start_ns = perf_counter_ns()
end_ns = perf_counter_ns() + 20 * 60 * 1_000_000_000 # 20 minutes
has_correct_state = False
Expand All @@ -218,16 +230,17 @@ def wait_until_endpoint_has_members(self, rds_client, expected_members: Set[str]
pytest.fail(f"Timed out while waiting for the custom endpoint to stabilize: "
f"'{TestCustomEndpoint.endpoint_id}'.")

rds_utils.make_sure_instances_up(list(expected_members))
duration_sec = (perf_counter_ns() - start_ns) / 1_000_000_000
self.logger.debug(f"wait_until_endpoint_has_specified_members took {duration_sec} seconds.")

def test_custom_endpoint_failover(self, test_driver: TestDriver, conn_utils, props, rds_utils):
props["failover_mode"] = "reader_or_writer"
def test_custom_endpoint_failover(self, test_driver: TestDriver, conn_utils, props_with_failover, rds_utils):
props_with_failover["failover_mode"] = "reader_or_writer"

target_driver_connect = DriverHelper.get_connect_func(test_driver)
kwargs = conn_utils.get_connect_params()
kwargs["host"] = self.endpoint_info["Endpoint"]
conn = AwsWrapperConnection.connect(target_driver_connect, **kwargs, **props)
conn = AwsWrapperConnection.connect(target_driver_connect, **kwargs, **props_with_failover)

endpoint_members = self.endpoint_info["StaticMembers"]
instance_id = rds_utils.query_instance_id(conn)
Expand Down Expand Up @@ -281,7 +294,7 @@ def _setup_custom_endpoint_role(self, target_driver_connect, conn_kwargs, rds_ut
self.logger.debug("Custom endpoint instance successfully set to role: " + host_role.name)

def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__with_reader_as_init_conn(
self, test_driver: TestDriver, conn_utils, props, rds_utils):
self, test_driver: TestDriver, conn_utils, props_with_failover, rds_utils):
'''
Will test for the following scenario:
1. Initially connect to a reader instance via the custom endpoint.
Expand All @@ -297,13 +310,13 @@ def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__wit
kwargs["host"] = self.endpoint_info["Endpoint"]
# This setting is not required for the test, but it allows us to also test re-creation of expired monitors since
# it takes more than 30 seconds to modify the cluster endpoint (usually around 140s).
props["custom_endpoint_idle_monitor_expiration_ms"] = 30_000
props["wait_for_custom_endpoint_info_timeout_ms"] = 30_000
props_with_failover["custom_endpoint_idle_monitor_expiration_ms"] = 30_000
props_with_failover["wait_for_custom_endpoint_info_timeout_ms"] = 30_000

# Ensure that we are starting with a reader connection
self._setup_custom_endpoint_role(target_driver_connect, kwargs, rds_utils, HostRole.READER)

conn = AwsWrapperConnection.connect(target_driver_connect, **kwargs, **props)
conn = AwsWrapperConnection.connect(target_driver_connect, **kwargs, **props_with_failover)
endpoint_members = self.endpoint_info["StaticMembers"]
original_reader_id = rds_utils.query_instance_id(conn)
assert original_reader_id in endpoint_members
Expand All @@ -323,7 +336,7 @@ def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__wit
)

try:
self.wait_until_endpoint_has_members(rds_client, {original_reader_id, writer_id})
self.wait_until_endpoint_has_members(rds_client, {original_reader_id, writer_id}, rds_utils)

# We should now be able to switch to writer.
conn.read_only = False
Expand All @@ -339,7 +352,7 @@ def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__wit
rds_client.modify_db_cluster_endpoint(
DBClusterEndpointIdentifier=self.endpoint_id,
StaticMembers=[original_reader_id])
self.wait_until_endpoint_has_members(rds_client, {original_reader_id})
self.wait_until_endpoint_has_members(rds_client, {original_reader_id}, rds_utils)

# We should not be able to switch again because new_member was removed from the custom endpoint.
# We are connected to the reader. Attempting to switch to the writer will throw an exception.
Expand All @@ -350,16 +363,16 @@ def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__wit

def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__with_writer_as_init_conn(
self, test_driver: TestDriver, conn_utils, props, rds_utils):
'''
"""
Will test for the following scenario:
1. Iniitially connect to the writer instance via the custom endpoint.
1. Initially connect to the writer instance via the custom endpoint.
2. Attempt to switch to reader instance - should succeed, but will still use writer instance as reader.
3. Modify the custom endpoint to add a reader instance as a static member.
4. Switch to reader instance - should succeed.
5. Switch back to writer instance - should succeed.
6. Modify the custom endpoint to remove the reader instance as a static member.
7. Attempt to switch to reader instance - should fail since the custom endpoint no longer has the reader instance.
'''
"""

target_driver_connect = DriverHelper.get_connect_func(test_driver)
kwargs = conn_utils.get_connect_params()
Expand Down Expand Up @@ -401,7 +414,7 @@ def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__wit
)

try:
self.wait_until_endpoint_has_members(rds_client, {original_writer_id, reader_id_to_add})
self.wait_until_endpoint_has_members(rds_client, {original_writer_id, reader_id_to_add}, rds_utils)
# We should now be able to switch to new_member.
conn.read_only = True
new_instance_id = rds_utils.query_instance_id(conn)
Expand All @@ -414,7 +427,7 @@ def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__wit
rds_client.modify_db_cluster_endpoint(
DBClusterEndpointIdentifier=self.endpoint_id,
StaticMembers=[original_writer_id])
self.wait_until_endpoint_has_members(rds_client, {original_writer_id})
self.wait_until_endpoint_has_members(rds_client, {original_writer_id}, rds_utils)

# We should not be able to switch again because new_member was removed from the custom endpoint.
# We are connected to the writer. Attempting to switch to the reader will not work but will intentionally
Expand Down
1 change: 1 addition & 0 deletions tests/integration/container/test_read_write_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,7 @@ def test_pooled_connection__cluster_url_failover(
TestEnvironmentFeatures.NETWORK_OUTAGES_ENABLED,
TestEnvironmentFeatures.ABORT_CONNECTION_SUPPORTED])
@disable_on_engines([DatabaseEngine.MYSQL])
@pytest.mark.repeat(10)  # Run this test case a few more times since it is a flaky test
def test_pooled_connection__failover_failed(
self,
test_environment: TestEnvironment,
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/container/utils/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _create() -> TestEnvironment:
xray_recorder.configure(daemon_address=xray_daemon_endpoint,
context_missing="IGNORE_ERROR",
sampler=LocalSampler(
{"version": 1, "default": {"fixed_target": 1, "rate": 1.0}}))
{"version": 1, "default": {"fixed_target": 1, "rate": 1.0}, "rules": []}))
global_sdk_config.set_sdk_enabled(True)

if TestEnvironmentFeatures.TELEMETRY_METRICS_ENABLED in env.get_features():
Expand Down
Loading