@@ -8,10 +8,11 @@ use pegboard::actor_kv;
88use pegboard:: pubsub_subjects:: GatewayReceiverSubject ;
99use rivet_envoy_protocol:: { self as protocol, PROTOCOL_VERSION , versioned} ;
1010use rivet_guard_core:: websocket_handle:: WebSocketReceiver ;
11+ use scc:: HashMap ;
1112use std:: sync:: { Arc , atomic:: Ordering } ;
1213use tokio:: sync:: { Mutex , MutexGuard , watch} ;
1314use universaldb:: utils:: end_of_key_range;
14- use universalpubsub:: PublishOpts ;
15+ use universalpubsub:: { PubSub , PublishOpts } ;
1516use vbare:: OwnedVersionedData ;
1617
1718use crate :: { LifecycleResult , actor_event_demuxer:: ActorEventDemuxer , conn:: Conn , errors} ;
@@ -366,7 +367,7 @@ async fn handle_message(
366367 }
367368 }
368369 protocol:: ToRivet :: ToRivetTunnelMessage ( tunnel_msg) => {
369- handle_tunnel_message ( & ctx, tunnel_msg)
370+ handle_tunnel_message ( & ctx, conn , tunnel_msg)
370371 . await
371372 . context ( "failed to handle tunnel message" ) ?;
372373 }
@@ -447,16 +448,41 @@ async fn ack_commands(
447448#[ tracing:: instrument( skip_all) ]
448449async fn handle_tunnel_message (
449450 ctx : & StandaloneCtx ,
451+ conn : & Conn ,
452+ msg : protocol:: ToRivetTunnelMessage ,
453+ ) -> Result < ( ) > {
454+ forward_tunnel_message (
455+ & ctx. ups ( ) . context ( "failed to get UPS instance for tunnel message" ) ?,
456+ ctx. config ( ) . pegboard ( ) . envoy_max_response_payload_size ( ) ,
457+ & conn. authorized_tunnel_routes ,
458+ msg,
459+ )
460+ . await
461+ }
462+
463+ async fn forward_tunnel_message (
464+ ups : & PubSub ,
465+ max_payload_size : usize ,
466+ authorized_tunnel_routes : & HashMap < ( protocol:: GatewayId , protocol:: RequestId ) , ( ) > ,
450467 msg : protocol:: ToRivetTunnelMessage ,
451468) -> Result < ( ) > {
452469 // Extract inner data length before consuming msg
453470 let inner_data_len = tunnel_message_inner_data_len ( & msg. message_kind ) ;
454471
455472 // Enforce incoming payload size
456- if inner_data_len > ctx . config ( ) . pegboard ( ) . envoy_max_response_payload_size ( ) {
473+ if inner_data_len > max_payload_size {
457474 return Err ( errors:: WsError :: InvalidPacket ( "payload too large" . to_string ( ) ) . build ( ) ) ;
458475 }
459476
477+ if !authorized_tunnel_routes
478+ . contains_async ( & ( msg. message_id . gateway_id , msg. message_id . request_id ) )
479+ . await
480+ {
481+ return Err (
482+ errors:: WsError :: InvalidPacket ( "unauthorized tunnel message" . to_string ( ) ) . build ( ) ,
483+ ) ;
484+ }
485+
460486 let gateway_reply_to = GatewayReceiverSubject :: new ( msg. message_id . gateway_id ) . to_string ( ) ;
461487 let msg_serialized =
462488 versioned:: ToGateway :: wrap_latest ( protocol:: ToGateway :: ToRivetTunnelMessage ( msg) )
@@ -470,9 +496,7 @@ async fn handle_tunnel_message(
470496 ) ;
471497
472498 // Publish message to UPS
473- ctx. ups ( )
474- . context ( "failed to get UPS instance for tunnel message" ) ?
475- . publish ( & gateway_reply_to, & msg_serialized, PublishOpts :: one ( ) )
499+ ups. publish ( & gateway_reply_to, & msg_serialized, PublishOpts :: one ( ) )
476500 . await
477501 . with_context ( || {
478502 format ! (
@@ -500,6 +524,86 @@ fn tunnel_message_inner_data_len(kind: &protocol::ToRivetTunnelMessageKind) -> u
500524 }
501525}
502526
527+ #[ cfg( test) ]
528+ mod tests {
529+ use std:: sync:: Arc ;
530+ use std:: time:: Duration ;
531+
532+ use super :: * ;
533+ use universalpubsub:: driver:: memory:: MemoryDriver ;
534+ use universalpubsub:: NextOutput ;
535+
536+ fn test_pubsub ( channel : & str ) -> PubSub {
537+ PubSub :: new ( Arc :: new ( MemoryDriver :: new ( channel. to_string ( ) ) ) )
538+ }
539+
540+ fn test_message ( gateway_id : [ u8 ; 4 ] , request_id : [ u8 ; 4 ] ) -> protocol:: ToRivetTunnelMessage {
541+ protocol:: ToRivetTunnelMessage {
542+ message_id : protocol:: MessageId {
543+ gateway_id,
544+ request_id,
545+ message_index : 0 ,
546+ } ,
547+ message_kind : protocol:: ToRivetTunnelMessageKind :: ToRivetResponseAbort ,
548+ }
549+ }
550+
551+ #[ tokio:: test]
552+ async fn rejects_unissued_tunnel_message_pairs ( ) {
553+ let pubsub = test_pubsub ( "pegboard-envoy-ws-to-tunnel-test-reject" ) ;
554+ let gateway_id = [ 1 , 2 , 3 , 4 ] ;
555+ let request_id = [ 5 , 6 , 7 , 8 ] ;
556+ let mut sub = pubsub
557+ . subscribe ( & GatewayReceiverSubject :: new ( gateway_id) . to_string ( ) )
558+ . await
559+ . unwrap ( ) ;
560+ let authorized_tunnel_routes = HashMap :: new ( ) ;
561+
562+ let err = forward_tunnel_message (
563+ & pubsub,
564+ 1024 ,
565+ & authorized_tunnel_routes,
566+ test_message ( gateway_id, request_id) ,
567+ )
568+ . await
569+ . unwrap_err ( ) ;
570+ assert ! ( err. to_string( ) . contains( "unauthorized tunnel message" ) ) ;
571+
572+ let recv = tokio:: time:: timeout ( Duration :: from_millis ( 50 ) , sub. next ( ) ) . await ;
573+ assert ! ( recv. is_err( ) ) ;
574+ }
575+
576+ #[ tokio:: test]
577+ async fn republishes_issued_tunnel_message_pairs ( ) {
578+ let pubsub = test_pubsub ( "pegboard-envoy-ws-to-tunnel-test-allow" ) ;
579+ let gateway_id = [ 9 , 10 , 11 , 12 ] ;
580+ let request_id = [ 13 , 14 , 15 , 16 ] ;
581+ let mut sub = pubsub
582+ . subscribe ( & GatewayReceiverSubject :: new ( gateway_id) . to_string ( ) )
583+ . await
584+ . unwrap ( ) ;
585+ let authorized_tunnel_routes = HashMap :: new ( ) ;
586+ let _ = authorized_tunnel_routes
587+ . insert_async ( ( gateway_id, request_id) , ( ) )
588+ . await ;
589+
590+ forward_tunnel_message (
591+ & pubsub,
592+ 1024 ,
593+ & authorized_tunnel_routes,
594+ test_message ( gateway_id, request_id) ,
595+ )
596+ . await
597+ . unwrap ( ) ;
598+
599+ let msg = tokio:: time:: timeout ( Duration :: from_secs ( 1 ) , sub. next ( ) )
600+ . await
601+ . unwrap ( )
602+ . unwrap ( ) ;
603+ assert ! ( matches!( msg, NextOutput :: Message ( _) ) ) ;
604+ }
605+ }
606+
503607async fn send_actor_kv_error ( conn : & Conn , request_id : u32 , message : & str ) -> Result < ( ) > {
504608 let res_msg = versioned:: ToEnvoy :: wrap_latest ( protocol:: ToEnvoy :: ToEnvoyKvResponse (
505609 protocol:: ToEnvoyKvResponse {
0 commit comments