@@ -7,7 +7,7 @@ use std::{
77 sync:: Arc ,
88 time:: { Duration , Instant } ,
99} ;
10- use tokio:: sync:: mpsc;
10+ use tokio:: sync:: { mpsc, watch } ;
1111use universalpubsub:: { NextOutput , PubSub , PublishOpts , Subscriber } ;
1212use vbare:: OwnedVersionedData ;
1313
@@ -22,15 +22,27 @@ pub enum TunnelMessageData {
2222 Timeout ,
2323}
2424
25+ pub struct InFlightRequestHandle {
26+ pub msg_rx : mpsc:: Receiver < TunnelMessageData > ,
27+ /// Used to check if the request handler has been dropped.
28+ ///
29+ /// This is separate from `msg_rx` there may still be messages that need to be sent to the
30+ /// request after `msg_rx` has dropped.
31+ pub drop_rx : watch:: Receiver < ( ) > ,
32+ }
33+
2534struct InFlightRequest {
2635 /// UPS subject to send messages to for this request.
2736 receiver_subject : String ,
2837 /// Sender for incoming messages to this request.
2938 msg_tx : mpsc:: Sender < TunnelMessageData > ,
39+ /// Used to check if the request handler has been dropped.
40+ drop_tx : watch:: Sender < ( ) > ,
3041 /// True once first message for this request has been sent (so runner learned reply_to).
3142 opened : bool ,
3243 pending_msgs : Vec < PendingMessage > ,
3344 hibernation_state : Option < HibernationState > ,
45+ stopping : bool ,
3446}
3547
3648pub struct PendingMessage {
@@ -87,28 +99,37 @@ impl SharedState {
8799 & self ,
88100 receiver_subject : String ,
89101 request_id : RequestId ,
90- ) -> mpsc :: Receiver < TunnelMessageData > {
102+ ) -> InFlightRequestHandle {
91103 let ( msg_tx, msg_rx) = mpsc:: channel ( 128 ) ;
104+ let ( drop_tx, drop_rx) = watch:: channel ( ( ) ) ;
92105
93106 match self . in_flight_requests . entry_async ( request_id) . await {
94107 Entry :: Vacant ( entry) => {
95108 entry. insert_entry ( InFlightRequest {
96109 receiver_subject,
97110 msg_tx,
111+ drop_tx,
98112 opened : false ,
99113 pending_msgs : Vec :: new ( ) ,
100114 hibernation_state : None ,
115+ stopping : false ,
101116 } ) ;
102117 }
103118 Entry :: Occupied ( mut entry) => {
104119 entry. receiver_subject = receiver_subject;
105120 entry. msg_tx = msg_tx;
121+ entry. drop_tx = drop_tx;
106122 entry. opened = false ;
107123 entry. pending_msgs . clear ( ) ;
124+
125+ if entry. stopping {
126+ entry. hibernation_state = None ;
127+ entry. stopping = false ;
128+ }
108129 }
109130 }
110131
111- msg_rx
132+ InFlightRequestHandle { msg_rx, drop_rx }
112133 }
113134
114135 pub async fn send_message (
@@ -401,64 +422,128 @@ impl SharedState {
401422 }
402423
403424 async fn gc ( & self ) {
425+ let mut interval = tokio:: time:: interval ( GC_INTERVAL ) ;
426+ interval. set_missed_tick_behavior ( tokio:: time:: MissedTickBehavior :: Skip ) ;
427+
428+ loop {
429+ interval. tick ( ) . await ;
430+
431+ self . gc_in_flight_requests ( ) . await ;
432+ }
433+ }
434+
435+ /// This will remove all in flight requests that are cancelled or had an ack timeout.
436+ ///
437+ /// Purging requests is done in a 2 phase commit in order to ensure that the InFlightRequest is
438+ /// kept until the ToClientWebSocketClose message has been successfully sent.
439+ ///
440+ /// If we did not use a 2 phase commit (i.e. a single `retain` for any GC purge), the
441+ /// InFlightRequest would be removed immediately and the runner would never receive the
442+ /// ToClientWebSocketClose.
443+ ///
444+ /// **Phase 1**
445+ ///
446+ /// 1a. Find requests that need to be purged (either closed by gateway or message acknowledgement took too long)
447+ /// 1b. Flag the request as `stopping` to prevent re-purging this request in the next GC tick
448+ /// 1c. Send a `Timeout` message to `msg_tx` which will terminate the task in `handle_websocket`
449+ /// 1d. Once both tasks terminate, `handle_websocket` sends the `ToClientWebSocketClose` to the in flight request
450+ /// 1e. `handle_websocket` exits and drops `drop_rx`
451+ ///
452+ /// **Phase 2**
453+ ///
454+ /// 2a. Remove all requests where it was flagged as stopping and `drop_rx` has been dropped
455+ async fn gc_in_flight_requests ( & self ) {
404456 #[ derive( Debug ) ]
405457 enum MsgGcReason {
458+ /// Gateway channel is closed and there are no pending messages
459+ GatewayClosed ,
406460 /// Any tunnel message not acked (TunnelAck)
407461 MessageNotAcked ,
408462 /// WebSocket pending messages (ToServerWebSocketMessageAck)
409463 WebSocketMessageNotAcked ,
410464 }
411465
412- let mut interval = tokio:: time:: interval ( GC_INTERVAL ) ;
413- interval. set_missed_tick_behavior ( tokio:: time:: MissedTickBehavior :: Skip ) ;
466+ let now = Instant :: now ( ) ;
414467
415- loop {
416- interval. tick ( ) . await ;
468+ // First, check if an in flight req is beyond the timeout for tunnel message ack and websocket
469+ // message ack
470+ self . in_flight_requests
471+ . iter_mut_async ( |mut entry| {
472+ let ( request_id, req) = & mut * entry;
417473
418- let now = Instant :: now ( ) ;
474+ if req. stopping {
475+ return true ;
476+ }
419477
420- self . in_flight_requests
421- . retain_async ( |request_id, req| {
422- if req. msg_tx . is_closed ( ) {
423- return false ;
478+ let reason = ' reason: {
479+ // If we have no pending messages of any kind and the channel is closed, remove the
480+ // in flight req
481+ if req. msg_tx . is_closed ( )
482+ && req. pending_msgs . is_empty ( )
483+ && req
484+ . hibernation_state
485+ . as_ref ( )
486+ . map ( |hs| hs. pending_ws_msgs . is_empty ( ) )
487+ . unwrap_or ( true )
488+ {
489+ break ' reason Some ( MsgGcReason :: GatewayClosed ) ;
424490 }
425491
426- let reason = ' reason: {
427- if let Some ( earliest_pending_msg) = req. pending_msgs . first ( ) {
428- if now. duration_since ( earliest_pending_msg. send_instant )
429- <= MESSAGE_ACK_TIMEOUT
430- {
431- break ' reason Some ( MsgGcReason :: MessageNotAcked ) ;
432- }
492+ if let Some ( earliest_pending_msg) = req. pending_msgs . first ( ) {
493+ if now. duration_since ( earliest_pending_msg. send_instant )
494+ <= MESSAGE_ACK_TIMEOUT
495+ {
496+ break ' reason Some ( MsgGcReason :: MessageNotAcked ) ;
433497 }
498+ }
434499
435- if let Some ( hs) = & req. hibernation_state
436- && let Some ( earliest_pending_ws_msg) = hs. pending_ws_msgs . first ( )
500+ if let Some ( hs) = & req. hibernation_state
501+ && let Some ( earliest_pending_ws_msg) = hs. pending_ws_msgs . first ( )
502+ {
503+ if now. duration_since ( earliest_pending_ws_msg. send_instant )
504+ <= MESSAGE_ACK_TIMEOUT
437505 {
438- if now. duration_since ( earliest_pending_ws_msg. send_instant )
439- <= MESSAGE_ACK_TIMEOUT
440- {
441- break ' reason Some ( MsgGcReason :: WebSocketMessageNotAcked ) ;
442- }
506+ break ' reason Some ( MsgGcReason :: WebSocketMessageNotAcked ) ;
443507 }
508+ }
444509
445- None
446- } ;
510+ None
511+ } ;
447512
448- if let Some ( reason) = & reason {
449- tracing:: debug!(
450- request_id=?Uuid :: from_bytes( * request_id) ,
451- ?reason,
452- "gc collecting in flight request"
453- ) ;
454- let _ = req. msg_tx . send ( TunnelMessageData :: Timeout ) ;
455- }
513+ if let Some ( reason) = & reason {
514+ tracing:: debug!(
515+ request_id=?Uuid :: from_bytes( * request_id) ,
516+ ?reason,
517+ "gc stopping in flight request"
518+ ) ;
456519
457- // Return true if the request was not gc'd
458- reason. is_none ( )
459- } )
460- . await ;
461- }
520+ let _ = req. msg_tx . send ( TunnelMessageData :: Timeout ) ;
521+
522+ // Mark req as stopping to skip this loop next time the gc is run
523+ req. stopping = true ;
524+ }
525+
526+ true
527+ } )
528+ . await ;
529+
530+ self . in_flight_requests
531+ . retain_async ( |request_id, req| {
532+ // The reason we check for stopping here is because drop_tx could be dropped if we are
533+ // between websocket retries (we don't want to remove the in flight req in this case).
534+ // When the websocket reconnects a new channel will be created
535+ if req. stopping && req. drop_tx . is_closed ( ) {
536+ tracing:: debug!(
537+ request_id=?Uuid :: from_bytes( * request_id) ,
538+ "gc removing in flight request"
539+ ) ;
540+
541+ return false ;
542+ }
543+
544+ true
545+ } )
546+ . await ;
462547 }
463548}
464549
0 commit comments