diff --git a/aws_advanced_python_wrapper/pep249_methods.py b/aws_advanced_python_wrapper/pep249_methods.py index 8d3f6487..0406c1d4 100644 --- a/aws_advanced_python_wrapper/pep249_methods.py +++ b/aws_advanced_python_wrapper/pep249_methods.py @@ -63,7 +63,7 @@ class DbApiMethod(Enum): CURSOR_NEXT = (30, "Cursor.__next__", False) CURSOR_LASTROWID = (31, "Cursor.lastrowid", False) - # AWS Advaced Python Wrapper Methods for + # AWS Advanced Python Wrapper Methods for the execution pipelines. CONNECT = (32, "connect", True) FORCE_CONNECT = (33, "force_connect", True) INIT_HOST_PROVIDER = (34, "init_host_provider", True) diff --git a/aws_advanced_python_wrapper/read_write_splitting_plugin.py b/aws_advanced_python_wrapper/read_write_splitting_plugin.py index 94e7c9b0..89ca90b6 100644 --- a/aws_advanced_python_wrapper/read_write_splitting_plugin.py +++ b/aws_advanced_python_wrapper/read_write_splitting_plugin.py @@ -28,6 +28,7 @@ ) from aws_advanced_python_wrapper.errors import (AwsWrapperError, FailoverError, + FailoverFailedError, ReadWriteSplittingError) from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole from aws_advanced_python_wrapper.pep249_methods import DbApiMethod @@ -49,6 +50,18 @@ class ReadWriteSplittingConnectionManager(Plugin): DbApiMethod.CONNECT.method_name, DbApiMethod.NOTIFY_CONNECTION_CHANGED.method_name, DbApiMethod.CONNECTION_SET_READ_ONLY.method_name, + + DbApiMethod.CONNECTION_COMMIT.method_name, + DbApiMethod.CONNECTION_AUTOCOMMIT.method_name, + DbApiMethod.CONNECTION_AUTOCOMMIT_SETTER.method_name, + DbApiMethod.CONNECTION_IS_READ_ONLY.method_name, + DbApiMethod.CONNECTION_SET_READ_ONLY.method_name, + DbApiMethod.CONNECTION_ROLLBACK.method_name, + + DbApiMethod.CURSOR_EXECUTE.method_name, + DbApiMethod.CURSOR_FETCHONE.method_name, + DbApiMethod.CURSOR_FETCHMANY.method_name, + DbApiMethod.CURSOR_FETCHALL.method_name } _POOL_PROVIDER_CLASS_NAME = 
"aws_advanced_python_wrapper.sql_alchemy_connection_provider.SqlAlchemyPooledConnectionProvider" @@ -137,12 +150,15 @@ def execute( try: return execute_func() except Exception as ex: + if isinstance(ex, FailoverFailedError): + # Evict the current connection from the pool right away since it is not reusable + self._close_connections(False) if isinstance(ex, FailoverError): logger.debug( "ReadWriteSplittingPlugin.FailoverExceptionWhileExecutingCommand", method_name ) - self._close_idle_connections() + self._close_connections(True) else: logger.debug( "ReadWriteSplittingPlugin.ExceptionWhileExecutingCommand", @@ -200,6 +216,7 @@ def _initialize_writer_connection(self): ) self._set_writer_connection(conn, writer_host) self._switch_current_connection_to(conn, writer_host) + return None def _switch_connection_if_required(self, read_only: bool): current_conn = self._plugin_service.current_connection @@ -293,7 +310,7 @@ def _switch_to_writer_connection(self): "ReadWriteSplittingPlugin.NoWriterFound") if self._is_reader_conn_from_internal_pool: - self._close_connection_if_idle(self._reader_connection) + self._close_connection(self._reader_connection) logger.debug( "ReadWriteSplittingPlugin.SwitchedFromReaderToWriter", @@ -319,7 +336,7 @@ def _switch_to_reader_connection(self): ) ): # The old reader cannot be used anymore, close it. 
- self._close_connection_if_idle(self._reader_connection) + self._close_connection(self._reader_connection) self._in_read_write_split = True if not self._is_connection_usable(self._reader_connection, driver_dialect): @@ -345,7 +362,7 @@ def _switch_to_reader_connection(self): self._initialize_reader_connection() if self._is_writer_conn_from_internal_pool: - self._close_connection_if_idle(self._writer_connection) + self._close_connection(self._writer_connection) def _initialize_reader_connection(self): if self._connection_handler.has_no_readers(): @@ -383,38 +400,35 @@ def _initialize_reader_connection(self): "ReadWriteSplittingPlugin.SwitchedFromWriterToReader", reader_host.url ) - def _close_connection_if_idle(self, internal_conn: Optional[Connection]): + def _close_connection(self, internal_conn: Optional[Connection], close_only_if_idle: bool = True): if internal_conn is None: return current_conn = self._plugin_service.current_connection driver_dialect = self._plugin_service.driver_dialect + if close_only_if_idle and internal_conn == current_conn: + # Connection is in use, do not close + return + try: - if internal_conn != current_conn and self._is_connection_usable( - internal_conn, driver_dialect - ): + if self._is_connection_usable(internal_conn, driver_dialect): driver_dialect.execute(DbApiMethod.CONNECTION_CLOSE.method_name, lambda: internal_conn.close()) - if internal_conn == self._writer_connection: - self._writer_connection = None - self._writer_host_info = None - if internal_conn == self._reader_connection: - self._reader_connection = None - self._reader_host_info = None except Exception: # Ignore exceptions during cleanup - connection might already be dead pass + finally: + if internal_conn == self._writer_connection: + self._writer_connection = None + self._writer_host_info = None + if internal_conn == self._reader_connection: + self._reader_connection = None + self._reader_host_info = None - def _close_idle_connections(self): + def 
_close_connections(self, close_only_if_idle: bool = True): logger.debug("ReadWriteSplittingPlugin.ClosingInternalConnections") - self._close_connection_if_idle(self._reader_connection) - self._close_connection_if_idle(self._writer_connection) - - # Always clear cached references even if connections couldn't be closed - self._reader_connection = None - self._reader_host_info = None - self._writer_connection = None - self._writer_host_info = None + self._close_connection(self._reader_connection, close_only_if_idle) + self._close_connection(self._writer_connection, close_only_if_idle) @staticmethod def log_and_raise_exception(log_msg: str): @@ -450,7 +464,7 @@ def host_list_provider_service(self) -> Optional[HostListProviderService]: ... @host_list_provider_service.setter - def host_list_provider_service(self, new_value: int) -> None: + def host_list_provider_service(self, new_value: HostListProviderService) -> None: """The setter for the 'host_list_provider_service' attribute.""" ... diff --git a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties index 8a0dd562..6d03cef3 100644 --- a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties +++ b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties @@ -462,7 +462,7 @@ UnknownDialect.AbortConnection=[UnknownDialect] abort_connection was called, but Wrapper.ConnectMethod=[Wrapper] Target driver should be a target driver's connect() method/function. Wrapper.RequiredTargetDriver=[Wrapper] Target driver is required. Wrapper.UnsupportedAttribute=[Wrapper] Target driver does not have the attribute: '{}' -Wrapper.Properties=[Wrapper] "Connection Properties: " +Wrapper.Properties=[Wrapper] "Connection Properties: {}" WriterFailoverHandler.AlreadyWriter=[WriterFailoverHandler] Current reader connection is actually a new writer connection. 
WriterFailoverHandler.CurrentTopologyNone=[WriterFailoverHandler] Current topology cannot be None. diff --git a/aws_advanced_python_wrapper/utils/telemetry/xray_telemetry.py b/aws_advanced_python_wrapper/utils/telemetry/xray_telemetry.py index 5ec1ac4c..728798ad 100644 --- a/aws_advanced_python_wrapper/utils/telemetry/xray_telemetry.py +++ b/aws_advanced_python_wrapper/utils/telemetry/xray_telemetry.py @@ -55,7 +55,9 @@ def set_success(self, success: bool): def set_attribute(self, key: str, value: AttributeValue): if self._trace_entity is not None: - self._trace_entity.put_annotation(key, value) + # XRay only supports str, bool, int, float - not sequences + if isinstance(value, (str, bool, int, float)): + self._trace_entity.put_annotation(key, value) def set_exception(self, exception: Exception): if self._trace_entity is not None and exception is not None: @@ -90,8 +92,7 @@ def _clone_and_close_context(context: XRayTelemetryContext, trace_level: Telemet clone._trace_entity.start_time = context._trace_entity.start_time - for key in context._trace_entity.annotations.items(): - value = context._trace_entity.annotations[key] + for key, value in context._trace_entity.annotations.items(): if key != TelemetryConst.TRACE_NAME_ANNOTATION and value is not None: clone.set_attribute(key, value) diff --git a/docs/examples/PGXRayTelemetry.py b/docs/examples/PGXRayTelemetry.py index 56f6365a..55aded07 100644 --- a/docs/examples/PGXRayTelemetry.py +++ b/docs/examples/PGXRayTelemetry.py @@ -29,7 +29,7 @@ print("-- running application") logging.basicConfig(level=logging.DEBUG) - xray_recorder.configure(sampler=LocalSampler({"version": 1, "default": {"fixed_target": 1, "rate": 1.0}})) + xray_recorder.configure(sampler=LocalSampler({"version": 1, "default": {"fixed_target": 1, "rate": 1.0}, "rules": []})) global_sdk_config.set_sdk_enabled(True) with xray_recorder.in_segment("python_xray_telemetry_app") as segment: diff --git a/tests/integration/container/test_aurora_failover.py 
b/tests/integration/container/test_aurora_failover.py index d8a4c2c7..18c25c9f 100644 --- a/tests/integration/container/test_aurora_failover.py +++ b/tests/integration/container/test_aurora_failover.py @@ -328,7 +328,7 @@ def test_writer_fail_within_transaction_start_transaction( @pytest.mark.parametrize("plugins", ["aurora_connection_tracker,failover", "aurora_connection_tracker,failover_v2"]) @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED]) - @pytest.mark.repeat(5) + @pytest.mark.repeat(5) # Run this test case a few more times since it is a flaky test def test_writer_failover_in_idle_connections( self, test_driver: TestDriver, props, conn_utils, aurora_utility, plugins): target_driver_connect = DriverHelper.get_connect_func(test_driver) diff --git a/tests/integration/container/test_custom_endpoint.py b/tests/integration/container/test_custom_endpoint.py index 2ee59b55..625e4ae8 100644 --- a/tests/integration/container/test_custom_endpoint.py +++ b/tests/integration/container/test_custom_endpoint.py @@ -61,9 +61,9 @@ def rds_utils(self): return RdsTestUtility(region) @pytest.fixture(scope='class') - def props(self): + def default_props(self): p: Properties = Properties( - {"plugins": "custom_endpoint,read_write_splitting,failover", "connect_timeout": 10_000, "autocommit": True, "cluster_id": "cluster1"}) + {"connect_timeout": 10_000, "autocommit": True, "cluster_id": "cluster1"}) features = TestEnvironment.get_current().get_features() if TestEnvironmentFeatures.TELEMETRY_TRACES_ENABLED in features \ @@ -77,6 +77,18 @@ def props(self): return p + @pytest.fixture(scope='class') + def props_with_failover(self, default_props): + p = default_props.copy() + p["plugins"] = "custom_endpoint,read_write_splitting,failover" + return p + + @pytest.fixture(scope='class') + def props(self, default_props): + p = default_props.copy() + p["plugins"] = "custom_endpoint,read_write_splitting" + return p + + @pytest.fixture(scope='class', autouse=True) def
setup_and_teardown(self): env_info = TestEnvironment.get_current().get_info() @@ -193,7 +205,7 @@ def _wait_until_endpoint_deleted(self, rds_client): else: self.logger.debug(f"Custom endpoint '{self.endpoint_id}' successfully deleted.") - def wait_until_endpoint_has_members(self, rds_client, expected_members: Set[str]): + def wait_until_endpoint_has_members(self, rds_client, expected_members: Set[str], rds_utils): start_ns = perf_counter_ns() end_ns = perf_counter_ns() + 20 * 60 * 1_000_000_000 # 20 minutes has_correct_state = False @@ -218,16 +230,17 @@ def wait_until_endpoint_has_members(self, rds_client, expected_members: Set[str] pytest.fail(f"Timed out while waiting for the custom endpoint to stabilize: " f"'{TestCustomEndpoint.endpoint_id}'.") + rds_utils.make_sure_instances_up(list(expected_members)) duration_sec = (perf_counter_ns() - start_ns) / 1_000_000_000 self.logger.debug(f"wait_until_endpoint_has_specified_members took {duration_sec} seconds.") - def test_custom_endpoint_failover(self, test_driver: TestDriver, conn_utils, props, rds_utils): - props["failover_mode"] = "reader_or_writer" + def test_custom_endpoint_failover(self, test_driver: TestDriver, conn_utils, props_with_failover, rds_utils): + props_with_failover["failover_mode"] = "reader_or_writer" target_driver_connect = DriverHelper.get_connect_func(test_driver) kwargs = conn_utils.get_connect_params() kwargs["host"] = self.endpoint_info["Endpoint"] - conn = AwsWrapperConnection.connect(target_driver_connect, **kwargs, **props) + conn = AwsWrapperConnection.connect(target_driver_connect, **kwargs, **props_with_failover) endpoint_members = self.endpoint_info["StaticMembers"] instance_id = rds_utils.query_instance_id(conn) @@ -281,7 +294,7 @@ def _setup_custom_endpoint_role(self, target_driver_connect, conn_kwargs, rds_ut self.logger.debug("Custom endpoint instance successfully set to role: " + host_role.name) def 
test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__with_reader_as_init_conn( - self, test_driver: TestDriver, conn_utils, props, rds_utils): + self, test_driver: TestDriver, conn_utils, props_with_failover, rds_utils): ''' Will test for the following scenario: 1. Initially connect to a reader instance via the custom endpoint. @@ -297,13 +310,13 @@ def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__wit kwargs["host"] = self.endpoint_info["Endpoint"] # This setting is not required for the test, but it allows us to also test re-creation of expired monitors since # it takes more than 30 seconds to modify the cluster endpoint (usually around 140s). - props["custom_endpoint_idle_monitor_expiration_ms"] = 30_000 - props["wait_for_custom_endpoint_info_timeout_ms"] = 30_000 + props_with_failover["custom_endpoint_idle_monitor_expiration_ms"] = 30_000 + props_with_failover["wait_for_custom_endpoint_info_timeout_ms"] = 30_000 # Ensure that we are starting with a reader connection self._setup_custom_endpoint_role(target_driver_connect, kwargs, rds_utils, HostRole.READER) - conn = AwsWrapperConnection.connect(target_driver_connect, **kwargs, **props) + conn = AwsWrapperConnection.connect(target_driver_connect, **kwargs, **props_with_failover) endpoint_members = self.endpoint_info["StaticMembers"] original_reader_id = rds_utils.query_instance_id(conn) assert original_reader_id in endpoint_members @@ -323,7 +336,7 @@ def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__wit ) try: - self.wait_until_endpoint_has_members(rds_client, {original_reader_id, writer_id}) + self.wait_until_endpoint_has_members(rds_client, {original_reader_id, writer_id}, rds_utils) # We should now be able to switch to writer. 
conn.read_only = False @@ -339,7 +352,7 @@ def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__wit rds_client.modify_db_cluster_endpoint( DBClusterEndpointIdentifier=self.endpoint_id, StaticMembers=[original_reader_id]) - self.wait_until_endpoint_has_members(rds_client, {original_reader_id}) + self.wait_until_endpoint_has_members(rds_client, {original_reader_id}, rds_utils) # We should not be able to switch again because new_member was removed from the custom endpoint. # We are connected to the reader. Attempting to switch to the writer will throw an exception. @@ -350,16 +363,16 @@ def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__wit def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__with_writer_as_init_conn( self, test_driver: TestDriver, conn_utils, props, rds_utils): - ''' + """ Will test for the following scenario: - 1. Iniitially connect to the writer instance via the custom endpoint. + 1. Initially connect to the writer instance via the custom endpoint. 2. Attempt to switch to reader instance - should succeed, but will still use writer instance as reader. 3. Modify the custom endpoint to add a reader instance as a static member. 4. Switch to reader instance - should succeed. 5. Switch back to writer instance - should succeed. 6. Modify the custom endpoint to remove the reader instance as a static member. 7. Attempt to switch to reader instance - should fail since the custom endpoint no longer has the reader instance. - ''' + """ target_driver_connect = DriverHelper.get_connect_func(test_driver) kwargs = conn_utils.get_connect_params() @@ -401,7 +414,7 @@ def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__wit ) try: - self.wait_until_endpoint_has_members(rds_client, {original_writer_id, reader_id_to_add}) + self.wait_until_endpoint_has_members(rds_client, {original_writer_id, reader_id_to_add}, rds_utils) # We should now be able to switch to new_member. 
conn.read_only = True new_instance_id = rds_utils.query_instance_id(conn) @@ -414,7 +427,7 @@ def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__wit rds_client.modify_db_cluster_endpoint( DBClusterEndpointIdentifier=self.endpoint_id, StaticMembers=[original_writer_id]) - self.wait_until_endpoint_has_members(rds_client, {original_writer_id}) + self.wait_until_endpoint_has_members(rds_client, {original_writer_id}, rds_utils) # We should not be able to switch again because new_member was removed from the custom endpoint. # We are connected to the writer. Attempting to switch to the reader will not work but will intentionally diff --git a/tests/integration/container/test_read_write_splitting.py b/tests/integration/container/test_read_write_splitting.py index a5e86443..c0c2f91c 100644 --- a/tests/integration/container/test_read_write_splitting.py +++ b/tests/integration/container/test_read_write_splitting.py @@ -832,6 +832,7 @@ def test_pooled_connection__cluster_url_failover( TestEnvironmentFeatures.NETWORK_OUTAGES_ENABLED, TestEnvironmentFeatures.ABORT_CONNECTION_SUPPORTED]) @disable_on_engines([DatabaseEngine.MYSQL]) + @pytest.mark.repeat(10) # Run this test case a few more times since it is a flaky test def test_pooled_connection__failover_failed( self, test_environment: TestEnvironment, diff --git a/tests/integration/container/utils/test_environment.py b/tests/integration/container/utils/test_environment.py index e547dba2..0b303207 100644 --- a/tests/integration/container/utils/test_environment.py +++ b/tests/integration/container/utils/test_environment.py @@ -102,7 +102,7 @@ def _create() -> TestEnvironment: xray_recorder.configure(daemon_address=xray_daemon_endpoint, context_missing="IGNORE_ERROR", sampler=LocalSampler( - {"version": 1, "default": {"fixed_target": 1, "rate": 1.0}})) + {"version": 1, "default": {"fixed_target": 1, "rate": 1.0}, "rules": []})) global_sdk_config.set_sdk_enabled(True) if
TestEnvironmentFeatures.TELEMETRY_METRICS_ENABLED in env.get_features(): diff --git a/tests/integration/host/build.gradle.kts b/tests/integration/host/build.gradle.kts index 7a3bc19c..c98afa21 100644 --- a/tests/integration/host/build.gradle.kts +++ b/tests/integration/host/build.gradle.kts @@ -13,32 +13,30 @@ repositories { } dependencies { - testImplementation("org.checkerframework:checker-qual:3.26.0") - testImplementation("org.junit.platform:junit-platform-commons:1.9.0") - testImplementation("org.junit.platform:junit-platform-engine:1.9.0") - testImplementation("org.junit.platform:junit-platform-launcher:1.9.0") - testImplementation("org.junit.platform:junit-platform-suite-engine:1.9.0") - testImplementation("org.junit.jupiter:junit-jupiter-api:5.9.1") - testImplementation("org.junit.jupiter:junit-jupiter-params:5.9.1") + testImplementation("org.checkerframework:checker-qual:3.49.0") + testImplementation("org.junit.platform:junit-platform-commons:1.11.4") + testImplementation("org.junit.platform:junit-platform-engine:1.11.4") + testImplementation("org.junit.platform:junit-platform-launcher:1.11.4") + testImplementation("org.junit.platform:junit-platform-suite-engine:1.11.4") + testImplementation("org.junit.jupiter:junit-jupiter-api:5.11.4") + testImplementation("org.junit.jupiter:junit-jupiter-params:5.11.4") testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine") - testImplementation("org.apache.commons:commons-dbcp2:2.9.0") - testImplementation("org.postgresql:postgresql:42.5.0") - testImplementation("mysql:mysql-connector-java:8.0.31") - testImplementation("org.springframework.boot:spring-boot-starter-jdbc:2.7.4") - testImplementation("org.mockito:mockito-inline:4.8.0") - testImplementation("software.amazon.awssdk:rds:2.20.49") - testImplementation("software.amazon.awssdk:ec2:2.20.61") - testImplementation("software.amazon.awssdk:secretsmanager:2.20.49") + testImplementation("org.apache.commons:commons-dbcp2:2.12.0") + 
testImplementation("org.postgresql:postgresql:42.7.5") + testImplementation("com.mysql:mysql-connector-j:9.1.0") + testImplementation("software.amazon.awssdk:rds:2.30.10") + testImplementation("software.amazon.awssdk:ec2:2.30.10") + testImplementation("software.amazon.awssdk:secretsmanager:2.30.10") // Note: all org.testcontainers dependencies should have the same version - testImplementation("org.testcontainers:testcontainers:1.21.2") - testImplementation("org.testcontainers:mysql:1.21.2") - testImplementation("org.testcontainers:postgresql:1.21.2") - testImplementation("org.testcontainers:junit-jupiter:1.21.2") - testImplementation("org.testcontainers:toxiproxy:1.21.2") - testImplementation("org.apache.poi:poi-ooxml:5.2.2") - testImplementation("org.slf4j:slf4j-simple:2.0.3") - testImplementation("com.fasterxml.jackson.core:jackson-databind:2.14.2") + testImplementation("org.testcontainers:testcontainers:2.0.3") + testImplementation("org.testcontainers:testcontainers-mysql:2.0.3") + testImplementation("org.testcontainers:testcontainers-postgresql:2.0.3") + testImplementation("org.testcontainers:testcontainers-toxiproxy:2.0.3") + testImplementation("org.testcontainers:testcontainers-junit-jupiter:2.0.3") + testImplementation("org.apache.poi:poi-ooxml:5.3.0") + testImplementation("org.slf4j:slf4j-simple:2.0.16") + testImplementation("com.fasterxml.jackson.core:jackson-databind:2.18.2") } tasks.test { diff --git a/tests/integration/host/src/test/java/integration/DriverHelper.java b/tests/integration/host/src/test/java/integration/DriverHelper.java index b4c4f679..65ed4acf 100644 --- a/tests/integration/host/src/test/java/integration/DriverHelper.java +++ b/tests/integration/host/src/test/java/integration/DriverHelper.java @@ -20,7 +20,6 @@ import java.sql.DriverManager; import java.sql.SQLException; import java.util.logging.Logger; -import org.testcontainers.shaded.org.apache.commons.lang3.NotImplementedException; public class DriverHelper { @@ -33,7 +32,7 @@ public 
static String getDriverProtocol(DatabaseEngine databaseEngine) { case PG: return "jdbc:postgresql://"; default: - throw new NotImplementedException(databaseEngine.toString()); + throw new UnsupportedOperationException(databaseEngine.toString()); } } @@ -44,7 +43,7 @@ public static String getDriverClassname(DatabaseEngine databaseEngine) { case PG: return getDriverClassname(TestDriver.PG); default: - throw new NotImplementedException(databaseEngine.toString()); + throw new UnsupportedOperationException(databaseEngine.toString()); } } @@ -55,7 +54,7 @@ public static String getDriverClassname(TestDriver testDriver) { case PG: return "org.postgresql.Driver"; default: - throw new NotImplementedException(testDriver.toString()); + throw new UnsupportedOperationException(testDriver.toString()); } } diff --git a/tests/integration/host/src/test/java/integration/host/TestEnvironment.java b/tests/integration/host/src/test/java/integration/host/TestEnvironment.java index 2eb36682..7179c1ec 100644 --- a/tests/integration/host/src/test/java/integration/host/TestEnvironment.java +++ b/tests/integration/host/src/test/java/integration/host/TestEnvironment.java @@ -50,7 +50,6 @@ import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.Network; import org.testcontainers.containers.ToxiproxyContainer; -import org.testcontainers.shaded.org.apache.commons.lang3.NotImplementedException; import software.amazon.awssdk.services.rds.model.BlueGreenDeployment; import software.amazon.awssdk.services.rds.model.DBCluster; import software.amazon.awssdk.services.rds.model.DBInstance; @@ -149,7 +148,7 @@ public static TestEnvironment build(TestEnvironmentRequest request) throws IOExc break; default: - throw new NotImplementedException(request.getDatabaseEngineDeployment().toString()); + throw new UnsupportedOperationException(request.getDatabaseEngineDeployment().toString()); } if (request.getFeatures().contains(TestEnvironmentFeatures.NETWORK_OUTAGES_ENABLED)) { 
@@ -272,7 +271,7 @@ private static TestEnvironment createAuroraOrMultiAzEnvironment(TestEnvironmentR configureIamAccess(env); break; default: - throw new NotImplementedException(request.getDatabaseEngineDeployment().toString()); + throw new UnsupportedOperationException(request.getDatabaseEngineDeployment().toString()); } return env; @@ -404,7 +403,7 @@ private static void createDatabaseContainers(TestEnvironment env) { } break; default: - throw new NotImplementedException(env.info.getRequest().getDatabaseInstances().toString()); + throw new UnsupportedOperationException(env.info.getRequest().getDatabaseInstances().toString()); } switch (env.info.getRequest().getDatabaseEngine()) { @@ -453,7 +452,7 @@ private static void createDatabaseContainers(TestEnvironment env) { break; default: - throw new NotImplementedException(env.info.getRequest().getDatabaseEngine().toString()); + throw new UnsupportedOperationException(env.info.getRequest().getDatabaseEngine().toString()); } } @@ -489,7 +488,7 @@ private static void createDbCluster(TestEnvironment env) { createDbCluster(env, env.numOfInstances); break; default: - throw new NotImplementedException(env.info.getRequest().getDatabaseEngine().toString()); + throw new UnsupportedOperationException(env.info.getRequest().getDatabaseEngine().toString()); } } @@ -852,7 +851,7 @@ private static String getDbEngine(TestEnvironmentRequest request) { case RDS_MULTI_AZ_INSTANCE: return getRdsEngine(request); default: - throw new NotImplementedException(request.getDatabaseEngineDeployment().toString()); + throw new UnsupportedOperationException(request.getDatabaseEngineDeployment().toString()); } } @@ -863,7 +862,7 @@ private static String getAuroraDbEngine(TestEnvironmentRequest request) { case PG: return "aurora-postgresql"; default: - throw new NotImplementedException(request.getDatabaseEngine().toString()); + throw new UnsupportedOperationException(request.getDatabaseEngine().toString()); } } @@ -874,7 +873,7 @@ private static 
String getRdsEngine(TestEnvironmentRequest request) { case PG: return "postgres"; default: - throw new NotImplementedException(request.getDatabaseEngine().toString()); + throw new UnsupportedOperationException(request.getDatabaseEngine().toString()); } } @@ -889,7 +888,7 @@ private static String getDbEngineVersion(String engineName, TestEnvironment env) systemPropertyVersion = config.pgVersion; break; default: - throw new NotImplementedException(request.getDatabaseEngine().toString()); + throw new UnsupportedOperationException(request.getDatabaseEngine().toString()); } return findEngineVersion(env, engineName, systemPropertyVersion); } @@ -919,7 +918,7 @@ private static int getPort(TestEnvironmentRequest request) { case PG: return 5432; default: - throw new NotImplementedException(request.getDatabaseEngine().toString()); + throw new UnsupportedOperationException(request.getDatabaseEngine().toString()); } } @@ -1148,7 +1147,7 @@ private static String getContainerBaseImageName(TestEnvironmentRequest request) case PYTHON_3_13: return "python:3.13"; default: - throw new NotImplementedException(request.getTargetPythonVersion().toString()); + throw new UnsupportedOperationException(request.getTargetPythonVersion().toString()); } } @@ -1315,7 +1314,7 @@ public void close() throws Exception { // do nothing break; default: - throw new NotImplementedException(this.info.getRequest().getDatabaseEngineDeployment().toString()); + throw new UnsupportedOperationException(this.info.getRequest().getDatabaseEngineDeployment().toString()); } } @@ -1490,7 +1489,7 @@ private static void preCreateEnvironment(int currentEnvIndex) { configureIamAccess(env); break; default: - throw new NotImplementedException(env.info.getRequest().getDatabaseEngineDeployment().toString()); + throw new UnsupportedOperationException(env.info.getRequest().getDatabaseEngineDeployment().toString()); } return env; diff --git a/tests/integration/host/src/test/java/integration/util/AuroraTestUtility.java 
b/tests/integration/host/src/test/java/integration/util/AuroraTestUtility.java index a04d44a9..1b8684ed 100644 --- a/tests/integration/host/src/test/java/integration/util/AuroraTestUtility.java +++ b/tests/integration/host/src/test/java/integration/util/AuroraTestUtility.java @@ -49,7 +49,6 @@ import java.util.logging.Logger; import java.util.stream.Collectors; import org.checkerframework.checker.nullness.qual.Nullable; -import org.testcontainers.shaded.org.apache.commons.lang3.NotImplementedException; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; @@ -897,7 +896,7 @@ public String getDbInstanceClass(TestEnvironmentRequest request) { case RDS_MULTI_AZ_CLUSTER: return "db.m5d.large"; default: - throw new NotImplementedException(request.getDatabaseEngineDeployment().toString()); + throw new UnsupportedOperationException(request.getDatabaseEngineDeployment().toString()); } } diff --git a/tests/integration/host/src/test/java/integration/util/ContainerHelper.java b/tests/integration/host/src/test/java/integration/util/ContainerHelper.java index 5895ebe7..4d7c7b02 100644 --- a/tests/integration/host/src/test/java/integration/util/ContainerHelper.java +++ b/tests/integration/host/src/test/java/integration/util/ContainerHelper.java @@ -49,10 +49,10 @@ public class ContainerHelper { - private static final String MYSQL_CONTAINER_IMAGE_NAME = "mysql:8.0.36"; + private static final String MYSQL_CONTAINER_IMAGE_NAME = "mysql:lts"; private static final String POSTGRES_CONTAINER_IMAGE_NAME = "postgres:latest"; private static final DockerImageName TOXIPROXY_IMAGE = - DockerImageName.parse("ghcr.io/shopify/toxiproxy:2.11.0"); + DockerImageName.parse("ghcr.io/shopify/toxiproxy:2.12.0"); private static final int PROXY_CONTROL_PORT = 8474; private static final int PROXY_PORT = 8666; @@ -317,12 +317,10 @@ public MySQLContainer 
createMysqlContainer( "--local_infile=1", "--max_allowed_packet=40M", "--max-connections=2048", - "--secure-file-priv=/var/lib/mysql", "--log_bin_trust_function_creators=1", "--character-set-server=utf8mb4", "--collation-server=utf8mb4_0900_as_cs", - "--skip-character-set-client-handshake", - "--log-error-verbosity=4"); + "--log-error-verbosity=3"); } public PostgreSQLContainer createPostgresContainer( diff --git a/tests/unit/test_read_write_splitting_plugin.py b/tests/unit/test_read_write_splitting_plugin.py index d80f4689..0308d565 100644 --- a/tests/unit/test_read_write_splitting_plugin.py +++ b/tests/unit/test_read_write_splitting_plugin.py @@ -327,7 +327,7 @@ def connect_side_effect(host: HostInfo, props, plugin): ) provider = SqlAlchemyPooledConnectionProvider( - lambda _, __: {"pool_size": 3}, None, 180000000000, 600000000000 # 3 minutes + lambda _, __: {"pool_size": 3}, None, None, 180000000000, 600000000000 # 3 minutes ) # 10 minutes conn_provider_manager_mock = mocker.MagicMock() @@ -345,7 +345,7 @@ def connect_side_effect(host: HostInfo, props, plugin): ) plugin = SimpleReadWriteSplittingPlugin(plugin_service_mock, srw_props) - spy = mocker.spy(plugin, "_close_connection_if_idle") + spy = mocker.spy(plugin, "_close_connection") plugin._switch_connection_if_required(True) plugin._switch_connection_if_required(False) @@ -373,7 +373,7 @@ def connect_side_effect(host: HostInfo, props, plugin): ) provider = SqlAlchemyPooledConnectionProvider( - lambda _, __: {"pool_size": 3}, None, 180000000000, 600000000000 # 3 minutes + lambda _, __: {"pool_size": 3}, None, None, 180000000000, 600000000000 # 3 minutes ) # 10 minutes conn_provider_manager_mock = mocker.MagicMock() @@ -391,7 +391,7 @@ def connect_side_effect(host: HostInfo, props, plugin): ) plugin = SimpleReadWriteSplittingPlugin(plugin_service_mock, srw_props) - spy = mocker.spy(plugin, "_close_connection_if_idle") + spy = mocker.spy(plugin, "_close_connection") plugin._switch_connection_if_required(True) 
plugin._switch_connection_if_required(False) @@ -476,7 +476,7 @@ def test_connect_non_initial_connection_read_write_splitting( connect_func_mock.assert_called() -def test_set_read_only_true_one_host_read_write_splitting(plugin_service_mock, read_write_splitting_plugin): +def test_set_read_only_true_one_host_read_write_splitting(plugin_service_mock, read_write_splitting_plugin, writer_conn_mock): plugin_service_mock.hosts = [writer_host] read_write_splitting_plugin._writer_connection = writer_conn_mock @@ -490,7 +490,7 @@ def test_set_read_only_true_one_host_read_write_splitting(plugin_service_mock, r def test_connect_error_updating_host_read_write_splitting( - plugin_service_mock, read_write_splitting_plugin, host_list_provider_service_mock, connect_func_mock, mocker): + plugin_service_mock, read_write_splitting_plugin, host_list_provider_service_mock, connect_func_mock, mocker, reader_conn_mock): def get_host_role_side_effect(conn): if conn == reader_conn_mock: return None @@ -516,7 +516,7 @@ def get_host_role_side_effect(conn): # Tests for the Simple Read/Write Splitting Plugin -def test_set_read_only_true_srw(srw_plugin, plugin_service_mock, reader_conn_mock): +def test_set_read_only_true_srw(srw_plugin, plugin_service_mock, reader_conn_mock, writer_conn_mock): plugin_service_mock.current_connection = writer_conn_mock plugin_service_mock.connect.return_value = reader_conn_mock @@ -534,7 +534,7 @@ def test_set_read_only_true_srw(srw_plugin, plugin_service_mock, reader_conn_moc def test_set_read_only_false_srw( - srw_plugin, plugin_service_mock, reader_conn_mock, writer_conn_mock,): + srw_plugin, plugin_service_mock, reader_conn_mock, writer_conn_mock): plugin_service_mock.current_connection = reader_conn_mock plugin_service_mock.connect.return_value = writer_conn_mock @@ -719,7 +719,7 @@ def test_connect_non_rds_cluster_endpoint_srw( def test_connect_non_rds_cluster_endpoint_with_verification_srw( - plugin_service_mock, connect_func_mock, writer_conn_mock, 
mocker,): + plugin_service_mock, connect_func_mock, writer_conn_mock, mocker, host_list_provider_service_mock): custom_host = HostInfo( host="custom-db.example.com", port=TEST_PORT, role=HostRole.WRITER )