@@ -38,10 +38,19 @@ impl<CF: ClientFactory> InstanceState<CF> {
3838 . ok_or_else ( || v4:: Error :: ConnectionFailed ( "no connection found" . into ( ) ) )
3939 }
4040
41- async fn is_address_allowed ( & self , address : & str ) -> Result < bool > {
42- let Ok ( config) = address. parse :: < tokio_postgres:: Config > ( ) else {
43- return Ok ( false ) ;
44- } ;
41+ #[ allow( clippy:: result_large_err) ]
42+ async fn ensure_address_allowed ( & self , address : & str ) -> Result < ( ) , v4:: Error > {
43+ fn conn_failed ( message : impl Into < String > ) -> v4:: Error {
44+ v4:: Error :: ConnectionFailed ( message. into ( ) )
45+ }
46+ fn err_other ( err : anyhow:: Error ) -> v4:: Error {
47+ v4:: Error :: Other ( err. to_string ( ) )
48+ }
49+
50+ let config = address
51+ . parse :: < tokio_postgres:: Config > ( )
52+ . map_err ( |e| conn_failed ( e. to_string ( ) ) ) ?;
53+
4554 for ( i, host) in config. get_hosts ( ) . iter ( ) . enumerate ( ) {
4655 match host {
4756 tokio_postgres:: config:: Host :: Tcp ( address) => {
@@ -55,15 +64,24 @@ impl<CF: ClientFactory> InstanceState<CF> {
5564 . or_else ( || if ports. len ( ) == 1 { ports. get ( 1 ) } else { None } ) ;
5665 let port_str = port. map ( |p| format ! ( ":{p}" ) ) . unwrap_or_default ( ) ;
5766 let url = format ! ( "{address}{port_str}" ) ;
58- if !self . allowed_hosts . check_url ( & url, "postgres" ) . await ? {
59- return Ok ( false ) ;
67+ if !self
68+ . allowed_hosts
69+ . check_url ( & url, "postgres" )
70+ . await
71+ . map_err ( err_other) ?
72+ {
73+ return Err ( conn_failed ( format ! (
74+ "address postgres://{url} is not permitted"
75+ ) ) ) ;
6076 }
6177 }
6278 #[ cfg( unix) ]
63- tokio_postgres:: config:: Host :: Unix ( _) => return Ok ( false ) ,
79+ tokio_postgres:: config:: Host :: Unix ( _) => {
80+ return Err ( conn_failed ( "Unix sockets are not supported on WebAssembly" ) ) ;
81+ }
6482 }
6583 }
66- Ok ( true )
84+ Ok ( ( ) )
6785 }
6886}
6987
@@ -82,15 +100,8 @@ impl<CF: ClientFactory> v3::HostConnection for InstanceState<CF> {
82100 async fn open ( & mut self , address : String ) -> Result < Resource < v3:: Connection > , v3:: Error > {
83101 spin_factor_outbound_networking:: record_address_fields ( & address) ;
84102
85- if !self
86- . is_address_allowed ( & address)
87- . await
88- . map_err ( |e| v3:: Error :: Other ( e. to_string ( ) ) ) ?
89- {
90- return Err ( v3:: Error :: ConnectionFailed ( format ! (
91- "address {address} is not permitted"
92- ) ) ) ;
93- }
103+ self . ensure_address_allowed ( & address) . await ?;
104+
94105 Ok ( self . open_connection ( & address) . await ?)
95106 }
96107
@@ -134,15 +145,8 @@ impl<CF: ClientFactory> v4::HostConnection for InstanceState<CF> {
134145 async fn open ( & mut self , address : String ) -> Result < Resource < v4:: Connection > , v4:: Error > {
135146 spin_factor_outbound_networking:: record_address_fields ( & address) ;
136147
137- if !self
138- . is_address_allowed ( & address)
139- . await
140- . map_err ( |e| v4:: Error :: Other ( e. to_string ( ) ) ) ?
141- {
142- return Err ( v4:: Error :: ConnectionFailed ( format ! (
143- "address {address} is not permitted"
144- ) ) ) ;
145- }
148+ self . ensure_address_allowed ( & address) . await ?;
149+
146150 self . open_connection ( & address) . await
147151 }
148152
@@ -199,11 +203,7 @@ impl<CF: ClientFactory> v4::Host for InstanceState<CF> {
199203/// Delegate a function call to the v3::HostConnection implementation
200204macro_rules! delegate {
201205 ( $self: ident. $name: ident( $address: expr, $( $arg: expr) ,* ) ) => { {
202- if !$self. is_address_allowed( & $address) . await . map_err( |e| v4:: Error :: Other ( e. to_string( ) ) ) ? {
203- return Err ( v1:: PgError :: ConnectionFailed ( format!(
204- "address {} is not permitted" , $address
205- ) ) ) ;
206- }
206+ $self. ensure_address_allowed( & $address) . await ?;
207207 let connection = match $self. open_connection( & $address) . await {
208208 Ok ( c) => c,
209209 Err ( e) => return Err ( e. into( ) ) ,
@@ -221,15 +221,8 @@ impl<CF: ClientFactory> v2::HostConnection for InstanceState<CF> {
221221 async fn open ( & mut self , address : String ) -> Result < Resource < v2:: Connection > , v2:: Error > {
222222 spin_factor_outbound_networking:: record_address_fields ( & address) ;
223223
224- if !self
225- . is_address_allowed ( & address)
226- . await
227- . map_err ( |e| v2:: Error :: Other ( e. to_string ( ) ) ) ?
228- {
229- return Err ( v2:: Error :: ConnectionFailed ( format ! (
230- "address {address} is not permitted"
231- ) ) ) ;
232- }
224+ self . ensure_address_allowed ( & address) . await ?;
225+
233226 Ok ( self . open_connection ( & address) . await ?)
234227 }
235228
0 commit comments