Skip to content

Commit b73d035

Browse files
committed
fix: boot dupe runners
1 parent a81bd73 commit b73d035

File tree

12 files changed

+284
-105
lines changed

12 files changed

+284
-105
lines changed

engine/packages/pegboard-gateway/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,8 +486,8 @@ impl CustomServeTrait for PegboardGateway {
486486
// Determine single result from both tasks
487487
let mut lifecycle_res = match (tunnel_to_ws_res, ws_to_tunnel_res) {
488488
// Prefer error
489-
(_, Err(err)) => Err(err),
490489
(Err(err), _) => Err(err),
490+
(_, Err(err)) => Err(err),
491491
// Prefer non aborted result if both succeed
492492
(Ok(res), Ok(LifecycleResult::Aborted)) => Ok(res),
493493
(Ok(LifecycleResult::Aborted), Ok(res)) => Ok(res),

engine/packages/pegboard-runner/src/conn.rs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,13 @@ pub struct TunnelActiveRequest {
2424
}
2525

2626
pub struct Conn {
27+
pub namespace_id: Id,
28+
pub runner_name: String,
29+
pub runner_key: String,
2730
pub runner_id: Id,
28-
2931
pub workflow_id: Id,
30-
3132
pub protocol_version: u16,
32-
3333
pub ws_handle: WebSocketHandle,
34-
3534
pub last_rtt: AtomicU32,
3635

3736
/// Active HTTP & WebSocket requests. They are separate but use the same mechanism to
@@ -63,7 +62,7 @@ pub async fn init_conn(
6362
let mut ws_rx = ws_rx.lock().await;
6463

6564
// Receive init packet
66-
let (runner_id, workflow_id) = if let Some(msg) =
65+
let (runner_name, runner_id, workflow_id) = if let Some(msg) =
6766
tokio::time::timeout(Duration::from_secs(5), ws_rx.next())
6867
.await
6968
.map_err(|_| WsError::TimedOutWaitingForInit.build())?
@@ -81,7 +80,7 @@ pub async fn init_conn(
8180
.map_err(|err| WsError::InvalidPacket(err.to_string()).build())
8281
.context("failed to deserialize initial packet from client")?;
8382

84-
let (runner_id, workflow_id) =
83+
let (runner_name, runner_id, workflow_id) =
8584
if let protocol::ToServer::ToServerInit(protocol::ToServerInit {
8685
name,
8786
version,
@@ -160,7 +159,7 @@ pub async fn init_conn(
160159
)
161160
})?;
162161

163-
(runner_id, workflow_id)
162+
(name.clone(), runner_id, workflow_id)
164163
} else {
165164
tracing::debug!(?packet, "invalid initial packet");
166165
return Err(WsError::InvalidInitialPacket("must be `ToServer::Init`").build());
@@ -178,12 +177,15 @@ pub async fn init_conn(
178177
)
179178
})?;
180179

181-
(runner_id, workflow_id)
180+
(runner_name, runner_id, workflow_id)
182181
} else {
183182
return Err(WsError::ConnectionClosed.build());
184183
};
185184

186185
Ok(Arc::new(Conn {
186+
namespace_id: namespace.namespace_id,
187+
runner_name,
188+
runner_key,
187189
runner_id,
188190
workflow_id,
189191
protocol_version,

engine/packages/pegboard-runner/src/lib.rs

Lines changed: 126 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,26 @@ use rivet_guard_core::{
1111
};
1212
use rivet_runner_protocol as protocol;
1313
use std::time::Duration;
14+
use tokio::sync::watch;
1415
use tokio_tungstenite::tungstenite::protocol::frame::{CloseFrame, coding::CloseCode};
1516
use universalpubsub::PublishOpts;
1617
use vbare::OwnedVersionedData;
1718

18-
mod client_to_pubsub_task;
1919
mod conn;
2020
mod errors;
2121
mod ping_task;
22-
mod pubsub_to_client_task;
22+
mod tunnel_to_ws_task;
2323
mod utils;
24+
mod ws_to_tunnel_task;
2425

2526
const UPDATE_PING_INTERVAL: Duration = Duration::from_secs(3);
2627

28+
#[derive(Debug)]
29+
enum LifecycleResult {
30+
Closed,
31+
Aborted,
32+
}
33+
2734
pub struct PegboardRunnerWsCustomServe {
2835
ctx: StandaloneCtx,
2936
}
@@ -79,52 +86,142 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe {
7986
.await
8087
.context("failed to initialize runner connection")?;
8188

82-
// Subscribe to pubsub topic for this runner before accepting the client websocket so
83-
// that failures can be retried by the proxy.
89+
// Subscribe before accepting the client websocket so that failures can be retried by the proxy.
8490
let topic =
8591
pegboard::pubsub_subjects::RunnerReceiverSubject::new(conn.runner_id).to_string();
86-
tracing::debug!(%topic, "subscribing to runner receiver topic");
92+
let eviction_topic =
93+
pegboard::pubsub_subjects::RunnerEvictionByIdSubject::new(conn.runner_id).to_string();
94+
let eviction_topic2 = pegboard::pubsub_subjects::RunnerEvictionByNameSubject::new(
95+
conn.namespace_id,
96+
&conn.runner_name,
97+
&conn.runner_key,
98+
)
99+
.to_string();
100+
101+
tracing::debug!(%topic, %eviction_topic, %eviction_topic2, "subscribing to runner topics");
87102
let sub = ups
88103
.subscribe(&topic)
89104
.await
90105
.with_context(|| format!("failed to subscribe to runner receiver topic: {}", topic))?;
106+
let mut eviction_sub = ups.subscribe(&eviction_topic).await.with_context(|| {
107+
format!(
108+
"failed to subscribe to runner eviction topic: {}",
109+
eviction_topic
110+
)
111+
})?;
112+
let mut eviction_sub2 = ups.subscribe(&eviction_topic2).await.with_context(|| {
113+
format!(
114+
"failed to subscribe to runner eviction topic: {}",
115+
eviction_topic2
116+
)
117+
})?;
118+
119+
// Publish eviction message to evict any currently connected runners with the same id or ns id +
120+
// runner name + runner key. This happens after subscribing to prevent race conditions.
121+
tokio::try_join!(
122+
async {
123+
ups.publish(&eviction_topic, &[], PublishOpts::broadcast())
124+
.await?;
125+
// Because we will receive our own message, skip the first message in the sub
126+
eviction_sub.next().await
127+
},
128+
async {
129+
ups.publish(&eviction_topic2, &[], PublishOpts::broadcast())
130+
.await?;
131+
eviction_sub2.next().await
132+
},
133+
)?;
91134

92-
// Forward pubsub -> WebSocket
93-
let mut pubsub_to_client = tokio::spawn(pubsub_to_client_task::task(conn.clone(), sub));
135+
let (tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch::channel(());
136+
let (ws_to_tunnel_abort_tx, ws_to_tunnel_abort_rx) = watch::channel(());
137+
let (ping_abort_tx, ping_abort_rx) = watch::channel(());
138+
139+
let tunnel_to_ws = tokio::spawn(tunnel_to_ws_task::task(
140+
conn.clone(),
141+
sub,
142+
eviction_sub,
143+
tunnel_to_ws_abort_rx,
144+
));
94145

95-
// Forward WebSocket -> pubsub
96-
let mut client_to_pubsub = tokio::spawn(client_to_pubsub_task::task(
146+
let ws_to_tunnel = tokio::spawn(ws_to_tunnel_task::task(
97147
self.ctx.clone(),
98148
conn.clone(),
99149
ws_handle.recv(),
150+
eviction_sub2,
151+
ws_to_tunnel_abort_rx,
100152
));
101153

102154
// Update pings
103-
let mut ping = tokio::spawn(ping_task::task(self.ctx.clone(), conn.clone()));
155+
let ping = tokio::spawn(ping_task::task(
156+
self.ctx.clone(),
157+
conn.clone(),
158+
ping_abort_rx,
159+
));
160+
let tunnel_to_ws_abort_tx2 = tunnel_to_ws_abort_tx.clone();
161+
let ws_to_tunnel_abort_tx2 = ws_to_tunnel_abort_tx.clone();
162+
let ping_abort_tx2 = ping_abort_tx.clone();
163+
164+
// Wait for all tasks to complete
165+
let (tunnel_to_ws_res, ws_to_tunnel_res, ping_res) = tokio::join!(
166+
async {
167+
let res = tunnel_to_ws.await?;
168+
169+
// Abort others if not aborted
170+
if !matches!(res, Ok(LifecycleResult::Aborted)) {
171+
tracing::debug!(?res, "tunnel to ws task completed, aborting others");
172+
173+
let _ = ping_abort_tx.send(());
174+
let _ = ws_to_tunnel_abort_tx.send(());
175+
} else {
176+
tracing::debug!(?res, "tunnel to ws task completed");
177+
}
104178

105-
// Wait for either task to complete
106-
let lifecycle_res = tokio::select! {
107-
res = &mut pubsub_to_client => {
108-
let res = res?;
109-
tracing::debug!(?res, "pubsub to WebSocket task completed");
110179
res
111-
}
112-
res = &mut client_to_pubsub => {
113-
let res = res?;
114-
tracing::debug!(?res, "WebSocket to pubsub task completed");
180+
},
181+
async {
182+
let res = ws_to_tunnel.await?;
183+
184+
// Abort others if not aborted
185+
if !matches!(res, Ok(LifecycleResult::Aborted)) {
186+
tracing::debug!(?res, "ws to tunnel task completed, aborting others");
187+
188+
let _ = ping_abort_tx2.send(());
189+
let _ = tunnel_to_ws_abort_tx.send(());
190+
} else {
191+
tracing::debug!(?res, "ws to tunnel task completed");
192+
}
193+
115194
res
116-
}
117-
res = &mut ping => {
118-
let res = res?;
119-
tracing::debug!(?res, "ping task completed");
195+
},
196+
async {
197+
let res = ping.await?;
198+
199+
// Abort others if not aborted
200+
if !matches!(res, Ok(LifecycleResult::Aborted)) {
201+
tracing::debug!(?res, "ping task completed, aborting others");
202+
203+
let _ = ws_to_tunnel_abort_tx2.send(());
204+
let _ = tunnel_to_ws_abort_tx2.send(());
205+
} else {
206+
tracing::debug!(?res, "ping task completed");
207+
}
208+
120209
res
121210
}
122-
};
211+
);
123212

124-
// Abort remaining tasks
125-
pubsub_to_client.abort();
126-
client_to_pubsub.abort();
127-
ping.abort();
213+
// Determine single result from all three tasks
214+
let lifecycle_res = match (tunnel_to_ws_res, ws_to_tunnel_res, ping_res) {
215+
// Prefer error
216+
(Err(err), _, _) => Err(err),
217+
(_, Err(err), _) => Err(err),
218+
(_, _, Err(err)) => Err(err),
219+
// Prefer the non-aborted result of the two tunnel tasks when all tasks succeed
220+
(Ok(res), Ok(LifecycleResult::Aborted), _) => Ok(res),
221+
(Ok(LifecycleResult::Aborted), Ok(res), _) => Ok(res),
222+
// Unlikely case
223+
(res, _, _) => res,
224+
};
128225

129226
// Make runner immediately ineligible when it disconnects
130227
let update_alloc_res = self
@@ -177,10 +274,7 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe {
177274
.context("failed to serialize tunnel message for gateway")?;
178275

179276
// Publish message to UPS
180-
let res = self
181-
.ctx
182-
.ups()
183-
.context("failed to get UPS instance for tunnel message")?
277+
let res = ups
184278
.publish(&req.gateway_reply_to, &msg_serialized, PublishOpts::one())
185279
.await;
186280

engine/packages/pegboard-runner/src/ping_task.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,24 @@
11
use gas::prelude::*;
22
use pegboard::ops::runner::update_alloc_idx::{Action, RunnerEligibility};
33
use std::sync::{Arc, atomic::Ordering};
4+
use tokio::sync::watch;
45

5-
use crate::{UPDATE_PING_INTERVAL, conn::Conn};
6+
use crate::{LifecycleResult, UPDATE_PING_INTERVAL, conn::Conn};
67

78
/// Updates the ping of all runners requesting a ping update at once.
89
#[tracing::instrument(skip_all)]
9-
pub async fn task(ctx: StandaloneCtx, conn: Arc<Conn>) -> Result<()> {
10+
pub async fn task(
11+
ctx: StandaloneCtx,
12+
conn: Arc<Conn>,
13+
mut ping_abort_rx: watch::Receiver<()>,
14+
) -> Result<LifecycleResult> {
1015
loop {
11-
tokio::time::sleep(UPDATE_PING_INTERVAL).await;
16+
tokio::select! {
17+
_ = tokio::time::sleep(UPDATE_PING_INTERVAL) => {}
18+
_ = ping_abort_rx.changed() => {
19+
return Ok(LifecycleResult::Aborted);
20+
}
21+
}
1222

1323
let Some(wf) = ctx
1424
.workflow::<pegboard::workflows::runner::Input>(conn.workflow_id)

engine/packages/pegboard-runner/src/pubsub_to_client_task.rs renamed to engine/packages/pegboard-runner/src/tunnel_to_ws_task.rs

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,43 @@ use gas::prelude::*;
33
use hyper_tungstenite::tungstenite::Message as WsMessage;
44
use rivet_runner_protocol::{self as protocol, versioned};
55
use std::sync::Arc;
6+
use tokio::sync::watch;
67
use universalpubsub::{NextOutput, Subscriber};
78
use vbare::OwnedVersionedData;
89

910
use crate::{
11+
LifecycleResult,
1012
conn::{Conn, TunnelActiveRequest},
1113
errors,
1214
};
1315

1416
#[tracing::instrument(skip_all, fields(runner_id=?conn.runner_id, workflow_id=?conn.workflow_id, protocol_version=%conn.protocol_version))]
15-
pub async fn task(conn: Arc<Conn>, mut sub: Subscriber) -> Result<()> {
16-
while let NextOutput::Message(ups_msg) = sub
17-
.next()
18-
.await
19-
.context("pubsub_to_client_task sub failed")?
20-
{
17+
pub async fn task(
18+
conn: Arc<Conn>,
19+
mut sub: Subscriber,
20+
mut eviction_sub: Subscriber,
21+
mut tunnel_to_ws_abort_rx: watch::Receiver<()>,
22+
) -> Result<LifecycleResult> {
23+
loop {
24+
let ups_msg = tokio::select! {
25+
res = sub.next() => {
26+
if let NextOutput::Message(ups_msg) = res.context("pubsub_to_client_task sub failed")? {
27+
ups_msg
28+
} else {
29+
tracing::debug!("tunnel sub closed");
30+
bail!("tunnel sub closed");
31+
}
32+
}
33+
_ = eviction_sub.next() => {
34+
tracing::debug!("runner evicted");
35+
return Err(errors::WsError::Eviction.build());
36+
}
37+
_ = tunnel_to_ws_abort_rx.changed() => {
38+
tracing::debug!("task aborted");
39+
return Ok(LifecycleResult::Aborted);
40+
}
41+
};
42+
2143
tracing::debug!(
2244
payload_len = ups_msg.payload.len(),
2345
"received message from pubsub, forwarding to WebSocket"
@@ -105,6 +127,4 @@ pub async fn task(conn: Arc<Conn>, mut sub: Subscriber) -> Result<()> {
105127
.await
106128
.context("failed to send message to WebSocket")?
107129
}
108-
109-
Ok(())
110130
}

0 commit comments

Comments
 (0)