Skip to content

Commit 50142b7

Browse files
committed
chore(guard): add support for routing runner ws to /runners/connect
1 parent 7ee6720 commit 50142b7

File tree

3 files changed

+51
-8
lines changed

3 files changed

+51
-8
lines changed

engine/packages/guard/src/routing/mod.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->
5151
.map(|v| v.eq_ignore_ascii_case("websocket"))
5252
.unwrap_or(false);
5353

54-
// First, check if this is an actor path-based route
54+
// MARK: Path-based routing
55+
// Route actor
5556
if let Some(actor_path_info) = parse_actor_path(path) {
5657
tracing::debug!(?actor_path_info, "routing using path-based actor routing");
5758

@@ -71,8 +72,15 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->
7172
}
7273
}
7374

74-
// Fallback to header-based routing
75-
// Extract target from WebSocket protocol or HTTP header
75+
// Route runner
76+
if let Some(routing_output) =
77+
runner::route_request_path_based(&ctx, host, path, headers).await?
78+
{
79+
return Ok(routing_output);
80+
}
81+
82+
// MARK: Header- & protocol-based routing (X-Rivet-Target)
83+
// Determine target
7684
let target = if is_websocket {
7785
// For WebSocket, parse the sec-websocket-protocol header
7886
headers

engine/packages/guard/src/routing/runner.rs

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::sync::Arc;
66
use super::{SEC_WEBSOCKET_PROTOCOL, X_RIVET_TOKEN};
77
pub(crate) const WS_PROTOCOL_TOKEN: &str = "rivet_token.";
88

9-
/// Route requests to the API service
9+
/// Route requests to the runner service using header-based routing
1010
#[tracing::instrument(skip_all)]
1111
pub async fn route_request(
1212
ctx: &StandaloneCtx,
@@ -19,8 +19,38 @@ pub async fn route_request(
1919
return Ok(None);
2020
}
2121

22-
tracing::debug!(?host, path, "routing to runner");
22+
tracing::debug!(?host, path, "routing to runner via header");
2323

24+
route_runner_internal(ctx, host, headers).await.map(Some)
25+
}
26+
27+
/// Route requests to the runner service using path-based routing
28+
/// Matches path: /runners/connect
29+
#[tracing::instrument(skip_all)]
30+
pub async fn route_request_path_based(
31+
ctx: &StandaloneCtx,
32+
host: &str,
33+
path: &str,
34+
headers: &hyper::HeaderMap,
35+
) -> Result<Option<RoutingOutput>> {
36+
// Check if path matches /runners/connect
37+
let path_without_query = path.split('?').next().unwrap_or(path);
38+
if path_without_query != "/runners/connect" {
39+
return Ok(None);
40+
}
41+
42+
tracing::debug!(?host, path, "routing to runner via path");
43+
44+
route_runner_internal(ctx, host, headers).await.map(Some)
45+
}
46+
47+
/// Internal runner routing logic shared by both header-based and path-based routing
48+
#[tracing::instrument(skip_all)]
49+
async fn route_runner_internal(
50+
ctx: &StandaloneCtx,
51+
host: &str,
52+
headers: &hyper::HeaderMap,
53+
) -> Result<RoutingOutput> {
2454
// Validate that the host is valid for the current datacenter
2555
let current_dc = ctx.config().topology().current_dc()?;
2656
if !current_dc.is_valid_regional_host(host) {
@@ -95,5 +125,5 @@ pub async fn route_request(
95125
}
96126

97127
let tunnel = pegboard_runner::PegboardRunnerWsCustomServe::new(ctx.clone());
98-
Ok(Some(RoutingOutput::CustomServe(Arc::new(tunnel))))
128+
Ok(RoutingOutput::CustomServe(Arc::new(tunnel)))
99129
}

engine/sdks/typescript/runner/src/mod.ts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -454,12 +454,17 @@ export class Runner {
454454
const wsEndpoint = endpoint
455455
.replace("http://", "ws://")
456456
.replace("https://", "wss://");
457-
return `${wsEndpoint}?protocol_version=${PROTOCOL_VERSION}&namespace=${encodeURIComponent(this.#config.namespace)}&runner_key=${encodeURIComponent(this.#config.runnerKey)}`;
457+
458+
// Ensure the endpoint ends with /runners/connect
459+
const baseUrl = wsEndpoint.endsWith("/")
460+
? wsEndpoint.slice(0, -1)
461+
: wsEndpoint;
462+
return `${baseUrl}/runners/connect?protocol_version=${PROTOCOL_VERSION}&namespace=${encodeURIComponent(this.#config.namespace)}&runner_key=${encodeURIComponent(this.#config.runnerKey)}`;
458463
}
459464

460465
// MARK: Runner protocol
461466
async #openPegboardWebSocket() {
462-
const protocols = ["rivet", `rivet_target.runner`];
467+
const protocols = ["rivet"];
463468
if (this.config.token)
464469
protocols.push(`rivet_token.${this.config.token}`);
465470

0 commit comments

Comments
 (0)