diff --git a/aws_advanced_python_wrapper/read_write_splitting_plugin.py b/aws_advanced_python_wrapper/read_write_splitting_plugin.py index 89ca90b6..1ef1e4ed 100644 --- a/aws_advanced_python_wrapper/read_write_splitting_plugin.py +++ b/aws_advanced_python_wrapper/read_write_splitting_plugin.py @@ -14,8 +14,9 @@ from __future__ import annotations +from abc import abstractmethod from copy import deepcopy -from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, Set, Tuple +from typing import TYPE_CHECKING, Any, Callable, Optional, Set, Tuple if TYPE_CHECKING: from aws_advanced_python_wrapper.driver_dialect import DriverDialect @@ -42,9 +43,7 @@ logger = Logger(__name__) -class ReadWriteSplittingConnectionManager(Plugin): - """Base class that manages connection switching logic.""" - +class AbstractReadWriteSplittingPlugin(Plugin): _SUBSCRIBED_METHODS: Set[str] = { DbApiMethod.INIT_HOST_PROVIDER.method_name, DbApiMethod.CONNECT.method_name, @@ -69,11 +68,9 @@ def __init__( self, plugin_service: PluginService, props: Properties, - connection_handler: ReadWriteConnectionHandler, ): self._plugin_service: PluginService = plugin_service self._properties: Properties = props - self._connection_handler: ReadWriteConnectionHandler = connection_handler self._writer_connection: Optional[Connection] = None self._reader_connection: Optional[Connection] = None self._writer_host_info: Optional[HostInfo] = None @@ -84,6 +81,7 @@ def __init__( self._is_reader_conn_from_internal_pool: bool = False self._is_writer_conn_from_internal_pool: bool = False self._in_read_write_split: bool = False + self._host_list_provider_service: Optional[HostListProviderService] = None @property def subscribed_methods(self) -> Set[str]: @@ -95,21 +93,9 @@ def init_host_provider( host_list_provider_service: HostListProviderService, init_host_provider_func: Callable, ): - self._connection_handler.host_list_provider_service = host_list_provider_service + self._host_list_provider_service = host_list_provider_service init_host_provider_func() - def connect( - self, - target_driver_func: Callable, - driver_dialect: DriverDialect, - host_info: HostInfo, - props: Properties, - is_initial_connection: bool, - connect_func: Callable, - ) -> Connection: - return self._connection_handler.get_verified_initial_connection( - host_info, is_initial_connection, lambda x: self._plugin_service.connect(x, props, self), connect_func) - def notify_connection_changed( self, changes: Set[ConnectionEvent] ) -> OldConnectionSuggestedAction: @@ -172,13 +158,9 @@ def _update_internal_connection_info(self): if current_conn is None or current_host is None: return - if self._connection_handler.should_update_writer_with_current_conn( - current_conn, current_host, self._writer_connection - ): + if self._should_update_writer_connection(current_conn, current_host): self._set_writer_connection(current_conn, current_host) - elif self._connection_handler.should_update_reader_with_current_conn( - current_conn, current_host, self._reader_connection - ): + elif self._should_update_reader_connection(current_conn, current_host): self._set_reader_connection(current_conn, current_host) def _set_writer_connection( @@ -199,25 +181,6 @@ def _set_reader_connection( "ReadWriteSplittingPlugin.SetReaderConnection", reader_host_info.url ) - def _initialize_writer_connection(self): - conn, writer_host = self._connection_handler.open_new_writer_connection(lambda x: self._plugin_service.connect(x, self._properties, self)) - if conn is None: - self.log_and_raise_exception( - "ReadWriteSplittingPlugin.FailedToConnectToWriter" - ) - return None - - provider = self._conn_provider_manager.get_connection_provider( - writer_host, self._properties - ) - self._is_writer_conn_from_internal_pool = ( - ReadWriteSplittingConnectionManager._POOL_PROVIDER_CLASS_NAME - in str(type(provider)) - ) - self._set_writer_connection(conn, writer_host) - self._switch_current_connection_to(conn, writer_host) - return None - def _switch_connection_if_required(self, read_only: bool): current_conn = self._plugin_service.current_connection driver_dialect = self._plugin_service.driver_dialect @@ -231,9 +194,7 @@ def _switch_connection_if_required(self, read_only: bool): "ReadWriteSplittingPlugin.SetReadOnlyOnClosedConnection" ) - self._connection_handler.refresh_and_store_host_list( - current_conn, driver_dialect - ) + self._refresh_and_store_topology(current_conn) current_host = self._plugin_service.current_host_info if current_host is None: @@ -243,7 +204,7 @@ def _switch_connection_if_required(self, read_only: bool): if read_only: if ( not self._plugin_service.is_in_transaction - and not self._connection_handler.is_reader_host(current_host) + and not self._is_reader(current_host) ): try: self._switch_to_reader_connection() @@ -258,7 +219,7 @@ def _switch_connection_if_required(self, read_only: bool): "ReadWriteSplittingPlugin.FallbackToCurrentConnection", current_host.url, ) - elif not self._connection_handler.is_writer_host(current_host): + elif not self._is_writer(current_host): if self._plugin_service.is_in_transaction: self.log_and_raise_exception( "ReadWriteSplittingPlugin.SetReadOnlyFalseInTransaction" @@ -290,24 +251,19 @@ def _switch_to_writer_connection(self): driver_dialect = self._plugin_service.driver_dialect if ( current_host is not None - and self._connection_handler.is_writer_host(current_host) + and self._is_writer(current_host) and self._is_connection_usable(current_conn, driver_dialect) ): # Already connected to the intended writer. return - self._writer_host_info = self._connection_handler.get_writer_host_info() self._in_read_write_split = True if not self._is_connection_usable(self._writer_connection, driver_dialect): self._initialize_writer_connection() elif self._writer_connection is not None and self._writer_host_info is not None: - if self._connection_handler.can_host_be_used(self._writer_host_info): - self._switch_current_connection_to( - self._writer_connection, self._writer_host_info - ) - else: - ReadWriteSplittingConnectionManager.log_and_raise_exception( - "ReadWriteSplittingPlugin.NoWriterFound") + self._switch_current_connection_to( + self._writer_connection, self._writer_host_info + ) if self._is_reader_conn_from_internal_pool: self._close_connection(self._reader_connection) @@ -323,20 +279,13 @@ def _switch_to_reader_connection(self): driver_dialect = self._plugin_service.driver_dialect if ( current_host is not None - and self._connection_handler.is_reader_host(current_host) + and self._is_reader(current_host) and self._is_connection_usable(current_conn, driver_dialect) ): # Already connected to the intended reader. return - if ( - self._reader_connection is not None - and not self._connection_handler.can_host_be_used( - self._reader_host_info - ) - ): - # The old reader cannot be used anymore, close it. - self._close_connection(self._reader_connection) + self._close_reader_if_necessary() self._in_read_write_split = True if not self._is_connection_usable(self._reader_connection, driver_dialect): @@ -356,7 +305,7 @@ def _switch_to_reader_connection(self): self._reader_host_info.url, ) - ReadWriteSplittingConnectionManager.close_connection(self._reader_connection, driver_dialect) + AbstractReadWriteSplittingPlugin.close_connection(self._reader_connection, driver_dialect) self._reader_connection = None self._reader_host_info = None self._initialize_reader_connection() @@ -364,42 +313,6 @@ def _switch_to_reader_connection(self): if self._is_writer_conn_from_internal_pool: self._close_connection(self._writer_connection) - def _initialize_reader_connection(self): - if self._connection_handler.has_no_readers(): - if not self._is_connection_usable( - self._writer_connection, self._plugin_service.driver_dialect - ): - self._initialize_writer_connection() - logger.warning( - "ReadWriteSplittingPlugin.NoReadersFound", self._writer_host_info.url - ) - return - - conn, reader_host = self._connection_handler.open_new_reader_connection(lambda x: self._plugin_service.connect(x, self._properties, self)) - - if conn is None or reader_host is None: - self.log_and_raise_exception("ReadWriteSplittingPlugin.NoReadersAvailable") - return - - logger.debug( - "ReadWriteSplittingPlugin.SuccessfullyConnectedToReader", reader_host.url - ) - - provider = self._conn_provider_manager.get_connection_provider( - reader_host, self._properties - ) - self._is_reader_conn_from_internal_pool = ( - ReadWriteSplittingConnectionManager._POOL_PROVIDER_CLASS_NAME - in str(type(provider)) - ) - - self._set_reader_connection(conn, reader_host) - self._switch_current_connection_to(conn, reader_host) - - logger.debug( - "ReadWriteSplittingPlugin.SwitchedFromWriterToReader", reader_host.url - ) - def _close_connection(self, internal_conn: Optional[Connection], close_only_if_idle: bool = True): if internal_conn is None: return @@ -454,89 +367,63 @@ def close_connection(conn: Optional[Connection], driver_dialect: DriverDialect): # Swallow exception return + # --- Abstract methods that concrete subclasses must implement --- -class ReadWriteConnectionHandler(Protocol): - """Protocol for handling writer/reader connection logic.""" - - @property - def host_list_provider_service(self) -> Optional[HostListProviderService]: - """Getter for the 'host_list_provider_service' attribute.""" - ... - - @host_list_provider_service.setter - def host_list_provider_service(self, new_value: HostListProviderService) -> None: - """The setter for the 'host_list_provider_service' attribute.""" - ... - - def open_new_writer_connection( - self, - plugin_service_connect_func: Callable[[HostInfo], Connection], - ) -> tuple[Optional[Connection], Optional[HostInfo]]: - """Open a writer connection.""" - ... - - def open_new_reader_connection( - self, - plugin_service_connect_func: Callable[[HostInfo], Connection], - ) -> tuple[Optional[Connection], Optional[HostInfo]]: - """Open a reader connection.""" - ... - - def get_verified_initial_connection( - self, - host_info: HostInfo, - is_initial_connection: bool, - plugin_service_connect_func: Callable[[HostInfo], Connection], - connect_func: Callable, - ) -> Connection: - """Verify initial connection or return normal workflow.""" - ... - - def should_update_writer_with_current_conn( - self, current_conn: Connection, current_host: HostInfo, writer_conn: Connection + @abstractmethod + def _should_update_writer_connection( + self, current_conn: Connection, current_host: HostInfo ) -> bool: - """Return true if the current connection fits the criteria of a writer connection.""" + """Return True if the current connection should update the cached writer connection.""" ... - def should_update_reader_with_current_conn( - self, current_conn: Connection, current_host: HostInfo, reader_conn: Connection + @abstractmethod + def _should_update_reader_connection( + self, current_conn: Connection, current_host: HostInfo ) -> bool: - """Return true if the current connection fits the criteria of a reader connection.""" + """Return True if the current connection should update the cached reader connection.""" ... - def is_writer_host(self, current_host: HostInfo) -> bool: - """Return true if the current host fits the criteria of a writer host.""" + @abstractmethod + def _is_writer(self, current_host: HostInfo) -> bool: + """Return True if the given host is a writer.""" ... - def is_reader_host(self, current_host: HostInfo) -> bool: - """Return true if the current host fits the criteria of a reader host.""" + @abstractmethod + def _is_reader(self, current_host: HostInfo) -> bool: + """Return True if the given host is a reader.""" ... - def can_host_be_used(self, host_info: HostInfo) -> bool: - """Returns true if connections can be switched to the given host""" + @abstractmethod + def _refresh_and_store_topology(self, current_conn: Optional[Connection]): + """Refresh the host list and store it for later use.""" ... - def has_no_readers(self) -> bool: - """Return true if there are no readers in the host list""" + @abstractmethod + def _initialize_writer_connection(self): + """Open a new writer connection and set it as the current connection.""" ... - def refresh_and_store_host_list( - self, current_conn: Optional[Connection], driver_dialect: DriverDialect - ): - """Refreshes the host list and then stores it.""" + @abstractmethod + def _initialize_reader_connection(self): + """Open a new reader connection and set it as the current connection.""" ... - def get_writer_host_info(self) -> Optional[HostInfo]: - """Get the current writer host info.""" + @abstractmethod + def _close_reader_if_necessary(self): + """Close the cached reader connection if it can no longer be used.""" ... -class TopologyBasedConnectionHandler(ReadWriteConnectionHandler): - """Topology based implementation of connection handling logic.""" +class ReadWriteSplittingPlugin(AbstractReadWriteSplittingPlugin): + """Topology-based read/write splitting plugin. + + Uses the host list topology to determine writer/reader roles and + select reader hosts via a configurable strategy. + """ def __init__(self, plugin_service: PluginService, props: Properties): - self._plugin_service: PluginService = plugin_service - self._host_list_provider_service: Optional[HostListProviderService] = None + # The read/write splitting plugin handles connections based on topology. + super().__init__(plugin_service, props) strategy = WrapperProperties.READER_HOST_SELECTOR_STRATEGY.get(props) if strategy is not None: self._reader_selector_strategy = strategy @@ -548,55 +435,13 @@ def __init__(self, plugin_service: PluginService, props: Properties): self._reader_selector_strategy = default_strategy self._hosts: Tuple[HostInfo, ...] = () - @property - def host_list_provider_service(self) -> Optional[HostListProviderService]: - return self._host_list_provider_service - - @host_list_provider_service.setter - def host_list_provider_service(self, new_value: HostListProviderService) -> None: - self._host_list_provider_service = new_value - - def open_new_writer_connection( - self, - plugin_service_connect_func: Callable[[HostInfo], Connection], - ) -> tuple[Optional[Connection], Optional[HostInfo]]: - writer_host = self.get_writer_host_info() - if writer_host is None: - return None, None - - conn = plugin_service_connect_func(writer_host) - - return conn, writer_host - - def open_new_reader_connection( - self, - plugin_service_connect_func: Callable[[HostInfo], Connection], - ) -> tuple[Optional[Connection], Optional[HostInfo]]: - conn: Optional[Connection] = None - reader_host: Optional[HostInfo] = None - - conn_attempts = len(self._plugin_service.hosts) * 2 - for _ in range(conn_attempts): - host = self._plugin_service.get_host_info_by_strategy( - HostRole.READER, self._reader_selector_strategy - ) - if host is not None: - try: - conn = plugin_service_connect_func(host) - reader_host = host - break - except Exception: - logger.warning( - "ReadWriteSplittingPlugin.FailedToConnectToReader", host.url - ) - - return conn, reader_host - - def get_verified_initial_connection( + def connect( self, + target_driver_func: Callable, + driver_dialect: DriverDialect, host_info: HostInfo, + props: Properties, is_initial_connection: bool, - plugin_service_connect_func: Callable[[HostInfo], Connection], connect_func: Callable, ) -> Connection: if not self._plugin_service.accepts_strategy( @@ -619,7 +464,7 @@ def get_verified_initial_connection( current_role = self._plugin_service.get_host_role(current_conn) if current_role is None or current_role == HostRole.UNKNOWN: - ReadWriteSplittingConnectionManager.log_and_raise_exception( + AbstractReadWriteSplittingPlugin.log_and_raise_exception( "ReadWriteSplittingPlugin.ErrorVerifyingInitialHostSpecRole" ) @@ -637,62 +482,141 @@ def get_verified_initial_connection( return current_conn - def can_host_be_used(self, host_info: HostInfo) -> bool: - hosts = [host_info.get_host_and_port() for host_info in self._hosts] - return host_info.get_host_and_port() in hosts + def _is_writer(self, current_host: HostInfo) -> bool: + return current_host.role == HostRole.WRITER - def has_no_readers(self) -> bool: - if len(self._hosts) == 1: - return self.get_writer_host_info() is not None - return False + def _is_reader(self, current_host: HostInfo) -> bool: + return current_host.role == HostRole.READER - def refresh_and_store_host_list( - self, current_conn: Optional[Connection], driver_dialect: DriverDialect - ): + def _refresh_and_store_topology(self, current_conn: Optional[Connection]): + driver_dialect = self._plugin_service.driver_dialect if current_conn is not None and driver_dialect.can_execute_query(current_conn): try: self._plugin_service.refresh_host_list() except Exception: - pass # Swallow exception + pass hosts = self._plugin_service.hosts if hosts is None or len(hosts) == 0: - ReadWriteSplittingConnectionManager.log_and_raise_exception( + AbstractReadWriteSplittingPlugin.log_and_raise_exception( "ReadWriteSplittingPlugin.EmptyHostList" ) self._hosts = hosts + self._writer_host_info = self._get_writer_host_info() + + def _initialize_writer_connection(self): + writer_host = self._get_writer_host_info() + if writer_host is None: + self.log_and_raise_exception( + "ReadWriteSplittingPlugin.FailedToConnectToWriter" + ) + return + + conn = self._plugin_service.connect(writer_host, self._properties, self) + if conn is None: + self.log_and_raise_exception( + "ReadWriteSplittingPlugin.FailedToConnectToWriter" + ) + return + + provider = self._conn_provider_manager.get_connection_provider( + writer_host, self._properties + ) + self._is_writer_conn_from_internal_pool = ( + AbstractReadWriteSplittingPlugin._POOL_PROVIDER_CLASS_NAME + in str(type(provider)) + ) + self._set_writer_connection(conn, writer_host) + self._switch_current_connection_to(conn, writer_host) + + def _initialize_reader_connection(self): + if len(self._hosts) == 1 and self._get_writer_host_info() is not None: + if not self._is_connection_usable( + self._writer_connection, self._plugin_service.driver_dialect + ): + self._initialize_writer_connection() + logger.warning( + "ReadWriteSplittingPlugin.NoReadersFound", self._writer_host_info.url + ) + return + + conn, reader_host = self._open_new_reader_connection() + + if conn is None or reader_host is None: + self.log_and_raise_exception("ReadWriteSplittingPlugin.NoReadersAvailable") + return + + logger.debug( + "ReadWriteSplittingPlugin.SuccessfullyConnectedToReader", reader_host.url + ) + + provider = self._conn_provider_manager.get_connection_provider( + reader_host, self._properties + ) + self._is_reader_conn_from_internal_pool = ( + AbstractReadWriteSplittingPlugin._POOL_PROVIDER_CLASS_NAME + in str(type(provider)) + ) + + self._set_reader_connection(conn, reader_host) + self._switch_current_connection_to(conn, reader_host) + + logger.debug( + "ReadWriteSplittingPlugin.SwitchedFromWriterToReader", reader_host.url + ) + + def _close_reader_if_necessary(self): + # The old reader cannot be used anymore, close it. + if ( + self._reader_host_info is not None + and self._reader_connection is not None + and not self._can_host_be_used(self._reader_host_info) + ): + self._close_connection(self._reader_connection) - def should_update_writer_with_current_conn( - self, current_conn, current_host: HostInfo, writer_conn: Connection + def _should_update_writer_connection( + self, current_conn: Connection, current_host: HostInfo ) -> bool: - return self.is_writer_host(current_host) + return self._is_writer(current_host) - def should_update_reader_with_current_conn( - self, current_conn, current_host, reader_conn: Connection + def _should_update_reader_connection( + self, current_conn: Connection, current_host: HostInfo ) -> bool: return True - def is_writer_host(self, current_host: HostInfo) -> bool: - return current_host.role == HostRole.WRITER - - def is_reader_host(self, current_host) -> bool: - return current_host.role == HostRole.READER - - def get_writer_host_info(self) -> Optional[HostInfo]: + def _get_writer_host_info(self) -> Optional[HostInfo]: for host in self._hosts: if host.role == HostRole.WRITER: return host - return None + def _can_host_be_used(self, host_info: HostInfo) -> bool: + hosts = [h.get_host_and_port() for h in self._hosts] + return host_info.get_host_and_port() in hosts + + def _open_new_reader_connection( + self, + ) -> tuple[Optional[Connection], Optional[HostInfo]]: + conn: Optional[Connection] = None + reader_host: Optional[HostInfo] = None -class ReadWriteSplittingPlugin(ReadWriteSplittingConnectionManager): - def __init__(self, plugin_service: PluginService, props: Properties): - # The read/write splitting plugin handles connections based on topology. - connection_handler = TopologyBasedConnectionHandler(plugin_service, props) + conn_attempts = len(self._plugin_service.hosts) * 2 + for _ in range(conn_attempts): + host = self._plugin_service.get_host_info_by_strategy( + HostRole.READER, self._reader_selector_strategy + ) + if host is not None: + try: + conn = self._plugin_service.connect(host, self._properties, self) + reader_host = host + break + except Exception: + logger.warning( + "ReadWriteSplittingPlugin.FailedToConnectToReader", host.url + ) - super().__init__(plugin_service, props, connection_handler) + return conn, reader_host class ReadWriteSplittingPluginFactory(PluginFactory): diff --git a/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py b/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py index 544e8775..76bb1b8e 100644 --- a/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py +++ b/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py @@ -18,14 +18,13 @@ from typing import TYPE_CHECKING, Callable, Optional, Type, TypeVar from aws_advanced_python_wrapper.host_availability import HostAvailability -from aws_advanced_python_wrapper.read_write_splitting_plugin import ( - ReadWriteConnectionHandler, ReadWriteSplittingConnectionManager) +from aws_advanced_python_wrapper.read_write_splitting_plugin import \ + AbstractReadWriteSplittingPlugin from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType from aws_advanced_python_wrapper.utils.rds_utils import RdsUtils if TYPE_CHECKING: from aws_advanced_python_wrapper.driver_dialect import DriverDialect - from aws_advanced_python_wrapper.host_list_provider import HostListProviderService from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.plugin_service import PluginService from aws_advanced_python_wrapper.utils.properties import Properties, WrapperProperty @@ -37,76 +36,57 @@ from aws_advanced_python_wrapper.utils.properties import WrapperProperties -class EndpointBasedConnectionHandler(ReadWriteConnectionHandler): - """Endpoint based implementation of connection handling logic.""" +class SimpleReadWriteSplittingPlugin(AbstractReadWriteSplittingPlugin): + """Endpoint-based read/write splitting plugin. + + Uses configured read and write endpoints to manage connection switching + rather than relying on topology information. + """ + + T = TypeVar('T') def __init__(self, plugin_service: PluginService, props: Properties): - read_endpoint: str = EndpointBasedConnectionHandler._verify_parameter( + # The simple read/write splitting plugin handles connections based on configuration parameter endpoints. + super().__init__(plugin_service, props) + + read_endpoint: str = SimpleReadWriteSplittingPlugin._verify_parameter( WrapperProperties.SRW_READ_ENDPOINT, props, str, required=True ) - write_endpoint: str = EndpointBasedConnectionHandler._verify_parameter( + write_endpoint: str = SimpleReadWriteSplittingPlugin._verify_parameter( WrapperProperties.SRW_WRITE_ENDPOINT, props, str, required=True ) - self._verify_new_connections: bool = EndpointBasedConnectionHandler._verify_parameter( + self._verify_new_connections: bool = SimpleReadWriteSplittingPlugin._verify_parameter( WrapperProperties.SRW_VERIFY_NEW_CONNECTIONS, props, bool ) if self._verify_new_connections: - self._connect_retry_timeout_ms: int = EndpointBasedConnectionHandler._verify_parameter( + self._connect_retry_timeout_ms: int = SimpleReadWriteSplittingPlugin._verify_parameter( WrapperProperties.SRW_CONNECT_RETRY_TIMEOUT_MS, props, int, lambda x: x > 0 ) - self._connect_retry_interval_ms: int = EndpointBasedConnectionHandler._verify_parameter( + self._connect_retry_interval_ms: int = SimpleReadWriteSplittingPlugin._verify_parameter( WrapperProperties.SRW_CONNECT_RETRY_INTERVAL_MS, props, int, lambda x: x > 0 ) self._verify_initial_connection_type: Optional[HostRole] = ( - EndpointBasedConnectionHandler._parse_role( + SimpleReadWriteSplittingPlugin._parse_role( WrapperProperties.SRW_VERIFY_INITIAL_CONNECTION_TYPE.get(props) ) ) - self._plugin_service: PluginService = plugin_service self._rds_utils: RdsUtils = RdsUtils() - self._host_list_provider_service: Optional[HostListProviderService] = None self._write_endpoint_host_info: HostInfo = self._create_host_info(write_endpoint, HostRole.WRITER) self._read_endpoint_host_info: HostInfo = self._create_host_info(read_endpoint, HostRole.READER) self._write_endpoint = write_endpoint.casefold() self._read_endpoint = read_endpoint.casefold() - @property - def host_list_provider_service(self) -> Optional[HostListProviderService]: - return self._host_list_provider_service - - @host_list_provider_service.setter - def host_list_provider_service(self, new_value: HostListProviderService) -> None: - self._host_list_provider_service = new_value - - def open_new_writer_connection( - self, - plugin_service_connect_func: Callable[[HostInfo], Connection], - ) -> tuple[Optional[Connection], Optional[HostInfo]]: - if self._verify_new_connections: - return self._get_verified_connection(self._write_endpoint_host_info, HostRole.WRITER, plugin_service_connect_func), \ - self._write_endpoint_host_info - - return plugin_service_connect_func(self._write_endpoint_host_info), self._write_endpoint_host_info - - def open_new_reader_connection( - self, - plugin_service_connect_func: Callable[[HostInfo], Connection], - ) -> tuple[Optional[Connection], Optional[HostInfo]]: - if self._verify_new_connections: - return self._get_verified_connection(self._read_endpoint_host_info, HostRole.READER, plugin_service_connect_func), \ - self._read_endpoint_host_info - - return plugin_service_connect_func(self._read_endpoint_host_info), self._read_endpoint_host_info - - def get_verified_initial_connection( + def connect( self, + target_driver_func: Callable, + driver_dialect: DriverDialect, host_info: HostInfo, + props: Properties, is_initial_connection: bool, - plugin_service_connect_func: Callable[[HostInfo], Connection], connect_func: Callable, ) -> Connection: if not is_initial_connection or not self._verify_new_connections: @@ -121,12 +101,14 @@ def get_verified_initial_connection( or url_type == RdsUrlType.RDS_GLOBAL_WRITER_CLUSTER or self._verify_initial_connection_type == HostRole.WRITER ): - conn = self._get_verified_connection(host_info, HostRole.WRITER, plugin_service_connect_func, connect_func) + conn = self._get_verified_connection(host_info, HostRole.WRITER, + lambda x: self._plugin_service.connect(x, props, self), connect_func) elif ( url_type == RdsUrlType.RDS_READER_CLUSTER or self._verify_initial_connection_type == HostRole.READER ): - conn = self._get_verified_connection(host_info, HostRole.READER, plugin_service_connect_func, connect_func) + conn = self._get_verified_connection(host_info, HostRole.READER, + lambda x: self._plugin_service.connect(x, props, self), connect_func) if conn is None: conn = connect_func() @@ -134,10 +116,100 @@ def get_verified_initial_connection( self._set_initial_connection_host_info(host_info) return conn + def _is_writer(self, current_host: HostInfo) -> bool: + return ( + current_host.host.casefold() == self._write_endpoint + or current_host.url.casefold() == self._write_endpoint + ) + + def _is_reader(self, current_host: HostInfo) -> bool: + return ( + current_host.host.casefold() == self._read_endpoint + or current_host.url.casefold() == self._read_endpoint + ) + + def _refresh_and_store_topology(self, current_conn: Optional[Connection]): + # Endpoint-based connections do not require a host list. + return + + def _initialize_writer_connection(self): + if self._verify_new_connections: + conn = self._get_verified_connection( + self._write_endpoint_host_info, HostRole.WRITER, + lambda x: self._plugin_service.connect(x, self._properties, self)) + else: + conn = self._plugin_service.connect(self._write_endpoint_host_info, self._properties, self) + + if conn is None: + self.log_and_raise_exception( + "ReadWriteSplittingPlugin.FailedToConnectToWriter" + ) + return + + provider = self._conn_provider_manager.get_connection_provider( + self._write_endpoint_host_info, self._properties + ) + self._is_writer_conn_from_internal_pool = ( + AbstractReadWriteSplittingPlugin._POOL_PROVIDER_CLASS_NAME + in str(type(provider)) + ) + self._set_writer_connection(conn, self._write_endpoint_host_info) + self._switch_current_connection_to(conn, self._write_endpoint_host_info) + + def _initialize_reader_connection(self): + if self._verify_new_connections: + conn = self._get_verified_connection( + self._read_endpoint_host_info, HostRole.READER, + lambda x: self._plugin_service.connect(x, self._properties, self)) + else: + conn = self._plugin_service.connect(self._read_endpoint_host_info, self._properties, self) + + if conn is None: + self.log_and_raise_exception("ReadWriteSplittingPlugin.NoReadersAvailable") + return + + provider = self._conn_provider_manager.get_connection_provider( + self._read_endpoint_host_info, self._properties + ) + self._is_reader_conn_from_internal_pool = ( + AbstractReadWriteSplittingPlugin._POOL_PROVIDER_CLASS_NAME + in str(type(provider)) + ) + + self._set_reader_connection(conn, self._read_endpoint_host_info) + self._switch_current_connection_to(conn, self._read_endpoint_host_info) + + def _close_reader_if_necessary(self): + # Endpoint-based connections always connect to the reader endpoint regardless. + pass + + def _should_update_writer_connection( + self, current_conn: Connection, current_host: HostInfo + ) -> bool: + return ( + self._is_writer(current_host) + and current_conn != self._writer_connection + and ( + not self._verify_new_connections + or self._plugin_service.get_host_role(current_conn) == HostRole.WRITER + ) + ) + + def _should_update_reader_connection( + self, current_conn: Connection, current_host: HostInfo + ) -> bool: + return ( + self._is_reader(current_host) + and current_conn != self._reader_connection + and ( + not self._verify_new_connections + or self._plugin_service.get_host_role(current_conn) == HostRole.READER + ) + ) + def _set_initial_connection_host_info(self, host_info: HostInfo): if self._host_list_provider_service is None: return - self._host_list_provider_service.initial_connection_host_info = host_info def _get_verified_connection( @@ -167,78 +239,25 @@ def _get_verified_connection( actual_role = self._plugin_service.get_host_role(candidate_conn) if actual_role != role: - ReadWriteSplittingConnectionManager.close_connection(candidate_conn, self._plugin_service.driver_dialect) + AbstractReadWriteSplittingPlugin.close_connection(candidate_conn, self._plugin_service.driver_dialect) self._delay() continue return candidate_conn except Exception: - ReadWriteSplittingConnectionManager.close_connection(candidate_conn, self._plugin_service.driver_dialect) + AbstractReadWriteSplittingPlugin.close_connection(candidate_conn, self._plugin_service.driver_dialect) self._delay() return None - def can_host_be_used(self, host_info: HostInfo) -> bool: - # Assume that the host can always be used, no topology-based information to check. - return True - - def has_no_readers(self) -> bool: - # SetReadOnly(true) will always connect to the read_endpoint, regardless of number of readers. - return False - - def refresh_and_store_host_list( - self, current_conn: Optional[Connection], driver_dialect: DriverDialect - ): - # Endpoint based connections do not require a host list. - return - - def should_update_writer_with_current_conn( - self, current_conn: Connection, current_host: HostInfo, writer_conn: Connection - ) -> bool: - return ( - self.is_writer_host(current_host) - and current_conn != writer_conn - and ( - not self._verify_new_connections - or self._plugin_service.get_host_role(current_conn) == HostRole.WRITER - ) - ) - - def should_update_reader_with_current_conn( - self, current_conn: Connection, current_host: HostInfo, reader_conn: Connection - ) -> bool: - return ( - self.is_reader_host(current_host) - and current_conn != reader_conn - and ( - not self._verify_new_connections - or self._plugin_service.get_host_role(current_conn) == HostRole.READER - ) - ) - - def is_writer_host(self, current_host: HostInfo) -> bool: - return ( - current_host.host.casefold() == self._write_endpoint - or current_host.url.casefold() == self._write_endpoint - ) - - def is_reader_host(self, current_host: HostInfo) -> bool: - return ( - current_host.host.casefold() == self._read_endpoint - or current_host.url.casefold() == self._read_endpoint - ) - - def get_writer_host_info(self) -> Optional[HostInfo]: - return self._write_endpoint_host_info - def _create_host_info(self, endpoint: str, role: HostRole) -> HostInfo: endpoint = endpoint.strip() host = endpoint try: port = self._plugin_service.database_dialect.default_port if not self._plugin_service.current_host_info.is_port_specified() \ else self._plugin_service.current_host_info.port - except AwsWrapperError: # if current_host_info cannot be determined fallback to default port + except AwsWrapperError: port = self._plugin_service.database_dialect.default_port colon_index = endpoint.rfind(":") @@ -252,8 +271,6 @@ def _create_host_info(self, endpoint: str, role: HostRole) -> HostInfo: host=host, port=port, role=role, availability=HostAvailability.AVAILABLE ) - T = TypeVar('T') - @staticmethod def _verify_parameter(prop: WrapperProperty, props: Properties, expected_type: Type[T], validator=None, required=False): value = prop.get_type(props, expected_type) @@ -297,14 +314,6 @@ def _parse_role(role_str: Optional[str]) -> HostRole: ) -class SimpleReadWriteSplittingPlugin(ReadWriteSplittingConnectionManager): - def __init__(self, plugin_service: PluginService, props: Properties): - # The simple read/write splitting plugin handles connections based on configuration parameter endpoints. - connection_handler = EndpointBasedConnectionHandler(plugin_service, props) - - super().__init__(plugin_service, props, connection_handler) - - class SimpleReadWriteSplittingPluginFactory(PluginFactory): @staticmethod def get_instance(plugin_service, props: Properties): diff --git a/tests/unit/test_read_write_splitting_plugin.py b/tests/unit/test_read_write_splitting_plugin.py index 0308d565..ab24cde5 100644 --- a/tests/unit/test_read_write_splitting_plugin.py +++ b/tests/unit/test_read_write_splitting_plugin.py @@ -131,7 +131,7 @@ def plugin_service_mock(mocker, driver_dialect_mock, writer_conn_mock): @pytest.fixture def read_write_splitting_plugin(plugin_service_mock, props, host_list_provider_service_mock): plugin = ReadWriteSplittingPlugin(plugin_service_mock, props) - plugin._connection_handler._host_list_provider_service = host_list_provider_service_mock + plugin._host_list_provider_service = host_list_provider_service_mock return plugin @@ -139,7 +139,7 @@ def read_write_splitting_plugin(plugin_service_mock, props, host_list_provider_s @pytest.fixture def srw_plugin(plugin_service_mock, srw_props, host_list_provider_service_mock): plugin = SimpleReadWriteSplittingPlugin(plugin_service_mock, srw_props) - plugin._connection_handler._host_list_provider_service = host_list_provider_service_mock + plugin._host_list_provider_service = host_list_provider_service_mock return plugin @@ -688,7 +688,7 @@ def test_connect_verification_fails_fallback_srw( plugin_service_mock.get_host_role.return_value = HostRole.READER # Wrong role plugin = SimpleReadWriteSplittingPlugin(plugin_service_mock, props) - plugin._connection_handler.host_list_provider_service = ( + plugin._host_list_provider_service = ( host_list_provider_service_mock ) @@ -738,7 +738,7 @@ def test_connect_non_rds_cluster_endpoint_with_verification_srw( ) plugin = SimpleReadWriteSplittingPlugin(plugin_service_mock, props) - plugin._connection_handler.host_list_provider_service = ( + plugin._host_list_provider_service = ( host_list_provider_service_mock )