Skip to content

Commit 09c295d

Browse files
MasterPtatoNathanFlurry
authored andcommitted
fix(gateway): fix gc logic
1 parent 4e625cf commit 09c295d

File tree

2 files changed

+135
-44
lines changed

2 files changed

+135
-44
lines changed

engine/packages/pegboard-gateway/src/lib.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use tokio_tungstenite::tungstenite::{
2525
protocol::frame::{CloseFrame, coding::CloseCode},
2626
};
2727

28-
use crate::shared_state::{SharedState, TunnelMessageData};
28+
use crate::shared_state::{InFlightRequestHandle, SharedState, TunnelMessageData};
2929

3030
pub mod shared_state;
3131

@@ -145,7 +145,10 @@ impl CustomServeTrait for PegboardGateway {
145145

146146
// Start listening for request responses
147147
let request_id = Uuid::new_v4().into_bytes();
148-
let mut msg_rx = self
148+
let InFlightRequestHandle {
149+
mut msg_rx,
150+
drop_rx: _drop_rx,
151+
} = self
149152
.shared_state
150153
.start_in_flight_request(tunnel_subject, request_id)
151154
.await;
@@ -258,7 +261,10 @@ impl CustomServeTrait for PegboardGateway {
258261

259262
// Start listening for WebSocket messages
260263
let request_id = unique_request_id.into_bytes();
261-
let mut msg_rx = self
264+
let InFlightRequestHandle {
265+
mut msg_rx,
266+
drop_rx: _drop_rx,
267+
} = self
262268
.shared_state
263269
.start_in_flight_request(tunnel_subject.clone(), request_id)
264270
.await;

engine/packages/pegboard-gateway/src/shared_state.rs

Lines changed: 126 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::{
77
sync::Arc,
88
time::{Duration, Instant},
99
};
10-
use tokio::sync::mpsc;
10+
use tokio::sync::{mpsc, watch};
1111
use universalpubsub::{NextOutput, PubSub, PublishOpts, Subscriber};
1212
use vbare::OwnedVersionedData;
1313

@@ -22,15 +22,27 @@ pub enum TunnelMessageData {
2222
Timeout,
2323
}
2424

25+
pub struct InFlightRequestHandle {
26+
pub msg_rx: mpsc::Receiver<TunnelMessageData>,
27+
/// Used to check if the request handler has been dropped.
28+
///
29+
/// This is separate from `msg_rx` there may still be messages that need to be sent to the
30+
/// request after `msg_rx` has dropped.
31+
pub drop_rx: watch::Receiver<()>,
32+
}
33+
2534
struct InFlightRequest {
2635
/// UPS subject to send messages to for this request.
2736
receiver_subject: String,
2837
/// Sender for incoming messages to this request.
2938
msg_tx: mpsc::Sender<TunnelMessageData>,
39+
/// Used to check if the request handler has been dropped.
40+
drop_tx: watch::Sender<()>,
3041
/// True once first message for this request has been sent (so runner learned reply_to).
3142
opened: bool,
3243
pending_msgs: Vec<PendingMessage>,
3344
hibernation_state: Option<HibernationState>,
45+
stopping: bool,
3446
}
3547

3648
pub struct PendingMessage {
@@ -87,28 +99,37 @@ impl SharedState {
8799
&self,
88100
receiver_subject: String,
89101
request_id: RequestId,
90-
) -> mpsc::Receiver<TunnelMessageData> {
102+
) -> InFlightRequestHandle {
91103
let (msg_tx, msg_rx) = mpsc::channel(128);
104+
let (drop_tx, drop_rx) = watch::channel(());
92105

93106
match self.in_flight_requests.entry_async(request_id).await {
94107
Entry::Vacant(entry) => {
95108
entry.insert_entry(InFlightRequest {
96109
receiver_subject,
97110
msg_tx,
111+
drop_tx,
98112
opened: false,
99113
pending_msgs: Vec::new(),
100114
hibernation_state: None,
115+
stopping: false,
101116
});
102117
}
103118
Entry::Occupied(mut entry) => {
104119
entry.receiver_subject = receiver_subject;
105120
entry.msg_tx = msg_tx;
121+
entry.drop_tx = drop_tx;
106122
entry.opened = false;
107123
entry.pending_msgs.clear();
124+
125+
if entry.stopping {
126+
entry.hibernation_state = None;
127+
entry.stopping = false;
128+
}
108129
}
109130
}
110131

111-
msg_rx
132+
InFlightRequestHandle { msg_rx, drop_rx }
112133
}
113134

114135
pub async fn send_message(
@@ -401,64 +422,128 @@ impl SharedState {
401422
}
402423

403424
async fn gc(&self) {
425+
let mut interval = tokio::time::interval(GC_INTERVAL);
426+
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
427+
428+
loop {
429+
interval.tick().await;
430+
431+
self.gc_in_flight_requests().await;
432+
}
433+
}
434+
435+
/// This will remove all in flight requests that are cancelled or had an ack timeout.
436+
///
437+
/// Purging requests is done in a 2 phase commit in order to ensure that the InFlightRequest is
438+
/// kept until the ToClientWebSocketClose message has been successfully sent.
439+
///
440+
/// If we did not use a 2 phase commit (i.e. a single `retain` for any GC purge), the
441+
/// InFlightRequest would be removed immediately and the runner would never receive the
442+
/// ToClientWebSocketClose.
443+
///
444+
/// **Phase 1**
445+
///
446+
/// 1a. Find requests that need to be purged (either closed by gateway or message acknowledgement took too long)
447+
/// 1b. Flag the request as `stopping` to prevent re-purging this request in the next GC tick
448+
/// 1c. Send a `Timeout` message to `msg_tx` which will terminate the task in `handle_websocket`
449+
/// 1d. Once both tasks terminate, `handle_websocket` sends the `ToClientWebSocketClose` to the in flight request
450+
/// 1e. `handle_websocket` exits and drops `drop_rx`
451+
///
452+
/// **Phase 2**
453+
///
454+
/// 2a. Remove all requests where it was flagged as stopping and `drop_rx` has been dropped
455+
async fn gc_in_flight_requests(&self) {
404456
#[derive(Debug)]
405457
enum MsgGcReason {
458+
/// Gateway channel is closed and there are no pending messages
459+
GatewayClosed,
406460
/// Any tunnel message not acked (TunnelAck)
407461
MessageNotAcked,
408462
/// WebSocket pending messages (ToServerWebSocketMessageAck)
409463
WebSocketMessageNotAcked,
410464
}
411465

412-
let mut interval = tokio::time::interval(GC_INTERVAL);
413-
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
466+
let now = Instant::now();
414467

415-
loop {
416-
interval.tick().await;
468+
// First, check if an in flight req is beyond the timeout for tunnel message ack and websocket
469+
// message ack
470+
self.in_flight_requests
471+
.iter_mut_async(|mut entry| {
472+
let (request_id, req) = &mut *entry;
417473

418-
let now = Instant::now();
474+
if req.stopping {
475+
return true;
476+
}
419477

420-
self.in_flight_requests
421-
.retain_async(|request_id, req| {
422-
if req.msg_tx.is_closed() {
423-
return false;
478+
let reason = 'reason: {
479+
// If we have no pending messages of any kind and the channel is closed, remove the
480+
// in flight req
481+
if req.msg_tx.is_closed()
482+
&& req.pending_msgs.is_empty()
483+
&& req
484+
.hibernation_state
485+
.as_ref()
486+
.map(|hs| hs.pending_ws_msgs.is_empty())
487+
.unwrap_or(true)
488+
{
489+
break 'reason Some(MsgGcReason::GatewayClosed);
424490
}
425491

426-
let reason = 'reason: {
427-
if let Some(earliest_pending_msg) = req.pending_msgs.first() {
428-
if now.duration_since(earliest_pending_msg.send_instant)
429-
<= MESSAGE_ACK_TIMEOUT
430-
{
431-
break 'reason Some(MsgGcReason::MessageNotAcked);
432-
}
492+
if let Some(earliest_pending_msg) = req.pending_msgs.first() {
493+
if now.duration_since(earliest_pending_msg.send_instant)
494+
<= MESSAGE_ACK_TIMEOUT
495+
{
496+
break 'reason Some(MsgGcReason::MessageNotAcked);
433497
}
498+
}
434499

435-
if let Some(hs) = &req.hibernation_state
436-
&& let Some(earliest_pending_ws_msg) = hs.pending_ws_msgs.first()
500+
if let Some(hs) = &req.hibernation_state
501+
&& let Some(earliest_pending_ws_msg) = hs.pending_ws_msgs.first()
502+
{
503+
if now.duration_since(earliest_pending_ws_msg.send_instant)
504+
<= MESSAGE_ACK_TIMEOUT
437505
{
438-
if now.duration_since(earliest_pending_ws_msg.send_instant)
439-
<= MESSAGE_ACK_TIMEOUT
440-
{
441-
break 'reason Some(MsgGcReason::WebSocketMessageNotAcked);
442-
}
506+
break 'reason Some(MsgGcReason::WebSocketMessageNotAcked);
443507
}
508+
}
444509

445-
None
446-
};
510+
None
511+
};
447512

448-
if let Some(reason) = &reason {
449-
tracing::debug!(
450-
request_id=?Uuid::from_bytes(*request_id),
451-
?reason,
452-
"gc collecting in flight request"
453-
);
454-
let _ = req.msg_tx.send(TunnelMessageData::Timeout);
455-
}
513+
if let Some(reason) = &reason {
514+
tracing::debug!(
515+
request_id=?Uuid::from_bytes(*request_id),
516+
?reason,
517+
"gc stopping in flight request"
518+
);
456519

457-
// Return true if the request was not gc'd
458-
reason.is_none()
459-
})
460-
.await;
461-
}
520+
let _ = req.msg_tx.send(TunnelMessageData::Timeout);
521+
522+
// Mark req as stopping to skip this loop next time the gc is run
523+
req.stopping = true;
524+
}
525+
526+
true
527+
})
528+
.await;
529+
530+
self.in_flight_requests
531+
.retain_async(|request_id, req| {
532+
// The reason we check for stopping here is because drop_tx could be dropped if we are
533+
// between websocket retries (we don't want to remove the in flight req in this case).
534+
// When the websocket reconnects a new channel will be created
535+
if req.stopping && req.drop_tx.is_closed() {
536+
tracing::debug!(
537+
request_id=?Uuid::from_bytes(*request_id),
538+
"gc removing in flight request"
539+
);
540+
541+
return false;
542+
}
543+
544+
true
545+
})
546+
.await;
462547
}
463548
}
464549

0 commit comments

Comments
 (0)