@@ -63,6 +63,15 @@ def clear_cache():
6363 SqlAlchemyPooledConnectionProvider ._database_pools .clear ()
6464
6565
66+ @pytest .fixture
67+ def mock_dialects (mocker ):
68+ mock_driver_dialect = mocker .MagicMock ()
69+ mock_database_dialect = mocker .MagicMock ()
70+ mock_driver_dialect .prepare_connect_info .side_effect = lambda host , props : Properties (props .copy ())
71+ mock_database_dialect .prepare_conn_props .return_value = None
72+ return mock_driver_dialect , mock_database_dialect
73+
74+
6675def test_connect__default_mapping__default_pool_configuration (provider , host_info , mocker , mock_conn , mock_pool ):
6776 expected_urls = {host_info .url }
6877 expected_keys = [PoolKey (host_info .url , "user1" )]
@@ -100,6 +109,69 @@ def test_connect__custom_configuration_and_mapping(host_info, mocker, mock_conn,
100109 mock_pool_initializer_func .assert_called_with (creator = mock_pool_connection_func , pool_size = 10 )
101110
102111
112+ def test_connect__updates_password_in_cached_pool_creator_props (host_info , mocker , mock_dialects ):
113+ captured_password : list = []
114+ props = Properties ({WrapperProperties .USER .name : "user1" , WrapperProperties .PASSWORD .name : "TOKEN_1" })
115+
116+ def fake_target_connect (** kwargs ):
117+ captured_password .append (kwargs .get ("password" ))
118+ return mocker .MagicMock (spec = psycopg .Connection )
119+
120+ provider = SqlAlchemyPooledConnectionProvider (
121+ pool_configurator = lambda _ , __ : {"pool_size" : 0 , "max_overflow" : 2 }
122+ )
123+ mock_driver_dialect , mock_database_dialect = mock_dialects
124+
125+ # Create a cached pool
126+ conn_1 = provider .connect (fake_target_connect , mock_driver_dialect , mock_database_dialect , host_info , props )
127+ assert conn_1 is not None
128+
129+ # Rotate password
130+ props [WrapperProperties .PASSWORD .name ] = "TOKEN_2"
131+
132+ conn_2 = provider .connect (fake_target_connect , mock_driver_dialect , mock_database_dialect , host_info , props )
133+ assert conn_2 is not None
134+
135+ assert captured_password == ["TOKEN_1" , "TOKEN_2" ]
136+
137+
138+ def test_connect__password_update_different_pool_keys (host_info , mocker , mock_dialects ):
139+ captured_password_user1 : list = []
140+ captured_password_user2 : list = []
141+ props_user1 = Properties ({WrapperProperties .USER .name : "user1" , WrapperProperties .PASSWORD .name : "TOKEN_1" })
142+ props_user2 = Properties ({WrapperProperties .USER .name : "user2" , WrapperProperties .PASSWORD .name : "TOKEN_2" })
143+
144+ def fake_target_connect (** kwargs ):
145+ user = kwargs .get ("user" )
146+ pwd = kwargs .get ("password" )
147+ if user == "user1" :
148+ captured_password_user1 .append (pwd )
149+ else :
150+ captured_password_user2 .append (pwd )
151+ return mocker .MagicMock (spec = psycopg .Connection )
152+
153+ provider = SqlAlchemyPooledConnectionProvider (
154+ pool_configurator = lambda _ , __ : {"pool_size" : 0 , "max_overflow" : 3 }
155+ )
156+ mock_driver_dialect , mock_database_dialect = mock_dialects
157+
158+ conn1 = provider .connect (fake_target_connect , mock_driver_dialect , mock_database_dialect , host_info , props_user1 )
159+ conn2 = provider .connect (fake_target_connect , mock_driver_dialect , mock_database_dialect , host_info , props_user2 )
160+ assert conn1 is not None
161+ assert conn2 is not None
162+
163+ assert captured_password_user1 == ["TOKEN_1" ]
164+ assert captured_password_user2 == ["TOKEN_2" ]
165+
166+ # Rotate password for user 1
167+ props_user1 [WrapperProperties .PASSWORD .name ] = "TOKEN_3"
168+ conn3 = provider .connect (fake_target_connect , mock_driver_dialect , mock_database_dialect , host_info , props_user1 )
169+ assert conn3 is not None
170+
171+ assert captured_password_user1 == ["TOKEN_1" , "TOKEN_3" ]
172+ assert captured_password_user2 == ["TOKEN_2" ]
173+
174+
103175def test_accepts_host_info (provider ):
104176 instance_url = "instance-1.XYZ.us-east-2.rds.amazonaws.com"
105177 instance_host_info = HostInfo (instance_url )
@@ -115,18 +187,24 @@ def test_least_connections_strategy(provider, mock_pool):
115187 writer = HostInfo ("writer.XYZ.us-east-2.rds.amazonaws.com" )
116188 reader_1 = HostInfo ("reader-1.XYZ.us-east-2.rds.amazonaws.com" , role = HostRole .READER )
117189 reader_2 = HostInfo ("reader-2.XYZ.us-east-2.rds.amazonaws.com" , role = HostRole .READER )
118- hosts = [ writer , reader_1 , reader_2 ]
190+ hosts = ( writer , reader_1 , reader_2 )
119191 props = Properties ({WrapperProperties .USER .name : "user1" , WrapperProperties .PASSWORD .name : "password" })
120192
121193 # Create cache with 1 pool to reader_url_1_connection and 2 pools to reader_url_2_connections.
122194 # Each pool holds 1 connection.
123195 test_database_pools = SlidingExpirationCache ()
124196 test_database_pools .compute_if_absent (
125- PoolKey (reader_1 .url , "user1" ), lambda _ : mock_pool , 10 * 60_000_000_000 )
197+ PoolKey (reader_1 .url , "user1" ),
198+ lambda _ : (mock_pool , Properties ()),
199+ 10 * 60_000_000_000 )
126200 test_database_pools .compute_if_absent (
127- PoolKey (reader_2 .url , "user1" ), lambda _ : mock_pool , 10 * 60_000_000_000 )
201+ PoolKey (reader_2 .url , "user1" ),
202+ lambda _ : (mock_pool , Properties ()),
203+ 10 * 60_000_000_000 )
128204 test_database_pools .compute_if_absent (
129- PoolKey (reader_2 .url , "user2" ), lambda _ : mock_pool , 10 * 60_000_000_000 )
205+ PoolKey (reader_2 .url , "user2" ),
206+ lambda _ : (mock_pool , Properties ()),
207+ 10 * 60_000_000_000 )
130208
131209 result = provider .get_host_info_by_strategy (hosts , HostRole .READER , "least_connections" , props )
132210 assert reader_1 == result
@@ -135,15 +213,21 @@ def test_least_connections_strategy(provider, mock_pool):
135213def test_least_connections_strategy__no_hosts_matching_role (provider ):
136214 props = Properties ()
137215 with pytest .raises (AwsWrapperError ):
138- provider .get_host_info_by_strategy ([ HostInfo ("writer" )] , HostRole .READER , "least_connections" , props )
216+ provider .get_host_info_by_strategy (( HostInfo ("writer" ),) , HostRole .READER , "least_connections" , props )
139217
140218
141219def test_release_resources (provider , mocker ):
142220 pool1 = mocker .MagicMock ()
143221 pool2 = mocker .MagicMock ()
144222 test_database_pools = SlidingExpirationCache ()
145- test_database_pools .compute_if_absent (PoolKey ("url1" , "user1" ), lambda _ : pool1 , 60_000_000_000 )
146- test_database_pools .compute_if_absent (PoolKey ("url1" , "user2" ), lambda _ : pool2 , 60_000_000_000 )
223+ test_database_pools .compute_if_absent (
224+ PoolKey ("url1" , "user1" ),
225+ lambda _ : (pool1 , Properties ()),
226+ 60_000_000_000 )
227+ test_database_pools .compute_if_absent (
228+ PoolKey ("url1" , "user2" ),
229+ lambda _ : (pool2 , Properties ()),
230+ 60_000_000_000 )
147231 SqlAlchemyPooledConnectionProvider ._database_pools = test_database_pools
148232
149233 provider .release_resources ()
0 commit comments