@@ -11,19 +11,26 @@ use rivet_guard_core::{
1111} ;
1212use rivet_runner_protocol as protocol;
1313use std:: time:: Duration ;
14+ use tokio:: sync:: watch;
1415use tokio_tungstenite:: tungstenite:: protocol:: frame:: { CloseFrame , coding:: CloseCode } ;
1516use universalpubsub:: PublishOpts ;
1617use vbare:: OwnedVersionedData ;
1718
18- mod client_to_pubsub_task;
1919mod conn;
2020mod errors;
2121mod ping_task;
22- mod pubsub_to_client_task ;
22+ mod tunnel_to_ws_task ;
2323mod utils;
24+ mod ws_to_tunnel_task;
2425
/// Fixed interval of 3 seconds.
// NOTE(review): not referenced in the visible code; presumably the period used by `ping_task` — confirm there.
2526const UPDATE_PING_INTERVAL : Duration = Duration :: from_secs ( 3 ) ;
2627
/// Successful outcome returned by each spawned connection task (tunnel-to-ws, ws-to-tunnel, ping).
///
/// The serve loop inspects this value: when a task finishes with anything other than
/// `Ok(LifecycleResult::Aborted)`, it signals the abort channels of the other tasks.
28+ #[ derive( Debug ) ]
29+ enum LifecycleResult {
// NOTE(review): presumably means the task ended on its own because the connection closed — confirm in the task modules.
30+ Closed ,
// NOTE(review): presumably returned when the task stopped in response to its abort channel — confirm in the task modules.
31+ Aborted ,
32+ }
33+
/// WebSocket serve handler for runner connections; implements `CustomServeTrait`.
2734pub struct PegboardRunnerWsCustomServe {
// Context handle that is cloned into the spawned ws-to-tunnel and ping tasks.
2835 ctx : StandaloneCtx ,
2936}
@@ -79,52 +86,142 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe {
7986 . await
8087 . context ( "failed to initialize runner connection" ) ?;
8188
82- // Subscribe to pubsub topic for this runner before accepting the client websocket so
83- // that failures can be retried by the proxy.
89+ // Subscribe before accepting the client websocket so that failures can be retried by the proxy.
8490 let topic =
8591 pegboard:: pubsub_subjects:: RunnerReceiverSubject :: new ( conn. runner_id ) . to_string ( ) ;
86- tracing:: debug!( %topic, "subscribing to runner receiver topic" ) ;
92+ let eviction_topic =
93+ pegboard:: pubsub_subjects:: RunnerEvictionByIdSubject :: new ( conn. runner_id ) . to_string ( ) ;
94+ let eviction_topic2 = pegboard:: pubsub_subjects:: RunnerEvictionByNameSubject :: new (
95+ conn. namespace_id ,
96+ & conn. runner_name ,
97+ & conn. runner_key ,
98+ )
99+ . to_string ( ) ;
100+
101+ tracing:: debug!( %topic, %eviction_topic, %eviction_topic2, "subscribing to runner topics" ) ;
87102 let sub = ups
88103 . subscribe ( & topic)
89104 . await
90105 . with_context ( || format ! ( "failed to subscribe to runner receiver topic: {}" , topic) ) ?;
106+ let mut eviction_sub = ups. subscribe ( & eviction_topic) . await . with_context ( || {
107+ format ! (
108+ "failed to subscribe to runner eviction topic: {}" ,
109+ eviction_topic
110+ )
111+ } ) ?;
112+ let mut eviction_sub2 = ups. subscribe ( & eviction_topic2) . await . with_context ( || {
113+ format ! (
114+ "failed to subscribe to runner eviction topic: {}" ,
115+ eviction_topic2
116+ )
117+ } ) ?;
118+
119+ // Publish eviction message to evict any currently connected runners with the same id or ns id +
120+ // runner name + runner key. This happens after subscribing to prevent race conditions.
121+ tokio:: try_join!(
122+ async {
123+ ups. publish( & eviction_topic, & [ ] , PublishOpts :: broadcast( ) )
124+ . await ?;
125+ // Because we will receive our own message, skip the first message in the sub
126+ eviction_sub. next( ) . await
127+ } ,
128+ async {
129+ ups. publish( & eviction_topic2, & [ ] , PublishOpts :: broadcast( ) )
130+ . await ?;
131+ eviction_sub2. next( ) . await
132+ } ,
133+ ) ?;
91134
92- // Forward pubsub -> WebSocket
93- let mut pubsub_to_client = tokio:: spawn ( pubsub_to_client_task:: task ( conn. clone ( ) , sub) ) ;
135+ let ( tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch:: channel ( ( ) ) ;
136+ let ( ws_to_tunnel_abort_tx, ws_to_tunnel_abort_rx) = watch:: channel ( ( ) ) ;
137+ let ( ping_abort_tx, ping_abort_rx) = watch:: channel ( ( ) ) ;
138+
139+ let tunnel_to_ws = tokio:: spawn ( tunnel_to_ws_task:: task (
140+ conn. clone ( ) ,
141+ sub,
142+ eviction_sub,
143+ tunnel_to_ws_abort_rx,
144+ ) ) ;
94145
95- // Forward WebSocket -> pubsub
96- let mut client_to_pubsub = tokio:: spawn ( client_to_pubsub_task:: task (
146+ let ws_to_tunnel = tokio:: spawn ( ws_to_tunnel_task:: task (
97147 self . ctx . clone ( ) ,
98148 conn. clone ( ) ,
99149 ws_handle. recv ( ) ,
150+ eviction_sub2,
151+ ws_to_tunnel_abort_rx,
100152 ) ) ;
101153
102154 // Update pings
103- let mut ping = tokio:: spawn ( ping_task:: task ( self . ctx . clone ( ) , conn. clone ( ) ) ) ;
155+ let ping = tokio:: spawn ( ping_task:: task (
156+ self . ctx . clone ( ) ,
157+ conn. clone ( ) ,
158+ ping_abort_rx,
159+ ) ) ;
160+ let tunnel_to_ws_abort_tx2 = tunnel_to_ws_abort_tx. clone ( ) ;
161+ let ws_to_tunnel_abort_tx2 = ws_to_tunnel_abort_tx. clone ( ) ;
162+ let ping_abort_tx2 = ping_abort_tx. clone ( ) ;
163+
164+ // Wait for all tasks to complete
165+ let ( tunnel_to_ws_res, ws_to_tunnel_res, ping_res) = tokio:: join!(
166+ async {
167+ let res = tunnel_to_ws. await ?;
168+
169+ // Abort others if not aborted
170+ if !matches!( res, Ok ( LifecycleResult :: Aborted ) ) {
171+ tracing:: debug!( ?res, "tunnel to ws task completed, aborting others" ) ;
172+
173+ let _ = ping_abort_tx. send( ( ) ) ;
174+ let _ = ws_to_tunnel_abort_tx. send( ( ) ) ;
175+ } else {
176+ tracing:: debug!( ?res, "tunnel to ws task completed" ) ;
177+ }
104178
105- // Wait for either task to complete
106- let lifecycle_res = tokio:: select! {
107- res = & mut pubsub_to_client => {
108- let res = res?;
109- tracing:: debug!( ?res, "pubsub to WebSocket task completed" ) ;
110179 res
111- }
112- res = & mut client_to_pubsub => {
113- let res = res?;
114- tracing:: debug!( ?res, "WebSocket to pubsub task completed" ) ;
180+ } ,
181+ async {
182+ let res = ws_to_tunnel. await ?;
183+
184+ // Abort others if not aborted
185+ if !matches!( res, Ok ( LifecycleResult :: Aborted ) ) {
186+ tracing:: debug!( ?res, "ws to tunnel task completed, aborting others" ) ;
187+
188+ let _ = ping_abort_tx2. send( ( ) ) ;
189+ let _ = tunnel_to_ws_abort_tx. send( ( ) ) ;
190+ } else {
191+ tracing:: debug!( ?res, "ws to tunnel task completed" ) ;
192+ }
193+
115194 res
116- }
117- res = & mut ping => {
118- let res = res?;
119- tracing:: debug!( ?res, "ping task completed" ) ;
195+ } ,
196+ async {
197+ let res = ping. await ?;
198+
199+ // Abort others if not aborted
200+ if !matches!( res, Ok ( LifecycleResult :: Aborted ) ) {
201+ tracing:: debug!( ?res, "ping task completed, aborting others" ) ;
202+
203+ let _ = ws_to_tunnel_abort_tx2. send( ( ) ) ;
204+ let _ = tunnel_to_ws_abort_tx2. send( ( ) ) ;
205+ } else {
206+ tracing:: debug!( ?res, "ping task completed" ) ;
207+ }
208+
120209 res
121210 }
122- } ;
211+ ) ;
123212
124- // Abort remaining tasks
125- pubsub_to_client. abort ( ) ;
126- client_to_pubsub. abort ( ) ;
127- ping. abort ( ) ;
213+ // Determine a single result from all three tasks
214+ let lifecycle_res = match ( tunnel_to_ws_res, ws_to_tunnel_res, ping_res) {
215+ // Prefer error
216+ ( Err ( err) , _, _) => Err ( err) ,
217+ ( _, Err ( err) , _) => Err ( err) ,
218+ ( _, _, Err ( err) ) => Err ( err) ,
219+ // No task errored: prefer the non-aborted result of the two tunnel tasks (ping's result is not consulted)
220+ ( Ok ( res) , Ok ( LifecycleResult :: Aborted ) , _) => Ok ( res) ,
221+ ( Ok ( LifecycleResult :: Aborted ) , Ok ( res) , _) => Ok ( res) ,
222+ // Unlikely case
223+ ( res, _, _) => res,
224+ } ;
128225
129226 // Make runner immediately ineligible when it disconnects
130227 let update_alloc_res = self
@@ -177,10 +274,7 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe {
177274 . context ( "failed to serialize tunnel message for gateway" ) ?;
178275
179276 // Publish message to UPS
180- let res = self
181- . ctx
182- . ups ( )
183- . context ( "failed to get UPS instance for tunnel message" ) ?
277+ let res = ups
184278 . publish ( & req. gateway_reply_to , & msg_serialized, PublishOpts :: one ( ) )
185279 . await ;
186280
0 commit comments