@@ -31,6 +31,7 @@ struct InFlightRequest {
3131 opened : bool ,
3232 pending_msgs : Vec < PendingMessage > ,
3333 hibernation_state : Option < HibernationState > ,
34+ stopping : bool ,
3435}
3536
3637pub struct PendingMessage {
@@ -98,13 +99,19 @@ impl SharedState {
9899 opened : false ,
99100 pending_msgs : Vec :: new ( ) ,
100101 hibernation_state : None ,
102+ stopping : false ,
101103 } ) ;
102104 }
103105 Entry :: Occupied ( mut entry) => {
104106 entry. receiver_subject = receiver_subject;
105107 entry. msg_tx = msg_tx;
106108 entry. opened = false ;
107109 entry. pending_msgs . clear ( ) ;
110+
111+ if entry. stopping {
112+ entry. hibernation_state = None ;
113+ entry. stopping = false ;
114+ }
108115 }
109116 }
110117
@@ -403,6 +410,8 @@ impl SharedState {
403410 async fn gc ( & self ) {
404411 #[ derive( Debug ) ]
405412 enum MsgGcReason {
413+ /// Gateway channel is closed and there are no pending messages
414+ GatewayClosed ,
406415 /// Any tunnel message not acked (TunnelAck)
407416 MessageNotAcked ,
408417 /// WebSocket pending messages (ToServerWebSocketMessageAck)
@@ -417,13 +426,30 @@ impl SharedState {
417426
418427 let now = Instant :: now ( ) ;
419428
429+ // First, check if an in flight req is beyond the timeout for tunnel message ack and websocket
430+ // message ack
420431 self . in_flight_requests
421- . retain_async ( |request_id, req| {
422- if req. msg_tx . is_closed ( ) {
423- return false ;
432+ . iter_mut_async ( |mut entry| {
433+ let ( request_id, req) = & mut * entry;
434+
435+ if req. stopping {
436+ return true ;
424437 }
425438
426439 let reason = ' reason: {
440+ // If we have no pending messages of any kind and the channel is closed, remove the
441+ // in flight req
442+ if req. msg_tx . is_closed ( )
443+ && req. pending_msgs . is_empty ( )
444+ && req
445+ . hibernation_state
446+ . as_ref ( )
447+ . map ( |hs| hs. pending_ws_msgs . is_empty ( ) )
448+ . unwrap_or ( true )
449+ {
450+ break ' reason Some ( MsgGcReason :: GatewayClosed ) ;
451+ }
452+
427453 if let Some ( earliest_pending_msg) = req. pending_msgs . first ( ) {
428454 if now. duration_since ( earliest_pending_msg. send_instant )
429455 <= MESSAGE_ACK_TIMEOUT
@@ -449,13 +475,34 @@ impl SharedState {
449475 tracing:: debug!(
450476 request_id=?Uuid :: from_bytes( * request_id) ,
451477 ?reason,
452- "gc collecting in flight request"
478+ "gc stopping in flight request"
453479 ) ;
480+
454481 let _ = req. msg_tx . send ( TunnelMessageData :: Timeout ) ;
482+
483+ // Mark req as stopping to skip this loop next time the gc is run
484+ req. stopping = true ;
485+ }
486+
487+ true
488+ } )
489+ . await ;
490+
491+ self . in_flight_requests
492+ . retain_async ( |request_id, req| {
493+ // The reason we check for stopping here is because msg_rx could be dropped if we are
494+ // between websocket retries (we don't want to remove the in flight req in this case).
495+ // When the websocket reconnects a new channel will be created
496+ if req. stopping && req. msg_tx . is_closed ( ) {
497+ tracing:: debug!(
498+ request_id=?Uuid :: from_bytes( * request_id) ,
499+ "gc removing in flight request"
500+ ) ;
501+
502+ return false ;
455503 }
456504
457- // Return true if the request was not gc'd
458- reason. is_none ( )
505+ true
459506 } )
460507 . await ;
461508 }
0 commit comments