From e86c5673f99f903c9feed2514e56877243b71371 Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Tue, 4 Nov 2025 17:45:07 +0000 Subject: [PATCH] chore(guard): add support for routing runner ws to `/runners/connect` --- engine/packages/guard/src/routing/mod.rs | 14 ++++++-- engine/packages/guard/src/routing/runner.rs | 36 +++++++++++++++++++-- engine/sdks/typescript/runner/src/mod.ts | 9 ++++-- 3 files changed, 51 insertions(+), 8 deletions(-) diff --git a/engine/packages/guard/src/routing/mod.rs b/engine/packages/guard/src/routing/mod.rs index 3424196e0b..35aaef552c 100644 --- a/engine/packages/guard/src/routing/mod.rs +++ b/engine/packages/guard/src/routing/mod.rs @@ -51,7 +51,8 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) -> .map(|v| v.eq_ignore_ascii_case("websocket")) .unwrap_or(false); - // First, check if this is an actor path-based route + // MARK: Path-based routing + // Route actor if let Some(actor_path_info) = parse_actor_path(path) { tracing::debug!(?actor_path_info, "routing using path-based actor routing"); @@ -71,8 +72,15 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) -> } } - // Fallback to header-based routing - // Extract target from WebSocket protocol or HTTP header + // Route runner + if let Some(routing_output) = + runner::route_request_path_based(&ctx, host, path, headers).await? + { + return Ok(routing_output); + } + + // MARK: Header- & protocol-based routing (X-Rivet-Target) + // Determine target let target = if is_websocket { // For WebSocket, parse the sec-websocket-protocol header headers diff --git a/engine/packages/guard/src/routing/runner.rs b/engine/packages/guard/src/routing/runner.rs index b963c1895d..f090d3b49c 100644 --- a/engine/packages/guard/src/routing/runner.rs +++ b/engine/packages/guard/src/routing/runner.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use super::{SEC_WEBSOCKET_PROTOCOL, X_RIVET_TOKEN}; pub(crate) const WS_PROTOCOL_TOKEN: &str = "rivet_token."; -/// Route requests to the API service +/// Route requests to the runner service using header-based routing #[tracing::instrument(skip_all)] pub async fn route_request( ctx: &StandaloneCtx, @@ -19,8 +19,38 @@ pub async fn route_request( return Ok(None); } - tracing::debug!(?host, path, "routing to runner"); + tracing::debug!(?host, path, "routing to runner via header"); + route_runner_internal(ctx, host, headers).await.map(Some) +} + +/// Route requests to the runner service using path-based routing +/// Matches path: /runners/connect +#[tracing::instrument(skip_all)] +pub async fn route_request_path_based( + ctx: &StandaloneCtx, + host: &str, + path: &str, + headers: &hyper::HeaderMap, +) -> Result> { + // Check if path matches /runners/connect + let path_without_query = path.split('?').next().unwrap_or(path); + if path_without_query != "/runners/connect" { + return Ok(None); + } + + tracing::debug!(?host, path, "routing to runner via path"); + + route_runner_internal(ctx, host, headers).await.map(Some) +} + +/// Internal runner routing logic shared by both header-based and path-based routing +#[tracing::instrument(skip_all)] +async fn route_runner_internal( + ctx: &StandaloneCtx, + host: &str, + headers: &hyper::HeaderMap, +) -> Result { // Validate that the host is valid for the current datacenter let current_dc = ctx.config().topology().current_dc()?; if !current_dc.is_valid_regional_host(host) { @@ -95,5 +125,5 @@ pub async fn route_request( } let tunnel = pegboard_runner::PegboardRunnerWsCustomServe::new(ctx.clone()); - Ok(Some(RoutingOutput::CustomServe(Arc::new(tunnel)))) + Ok(RoutingOutput::CustomServe(Arc::new(tunnel))) } diff --git a/engine/sdks/typescript/runner/src/mod.ts b/engine/sdks/typescript/runner/src/mod.ts index f29e6d2a37..aa403f8bc4 100644 --- a/engine/sdks/typescript/runner/src/mod.ts +++ b/engine/sdks/typescript/runner/src/mod.ts @@ -454,12 +454,17 @@ export class Runner { const wsEndpoint = endpoint .replace("http://", "ws://") .replace("https://", "wss://"); - return `${wsEndpoint}?protocol_version=${PROTOCOL_VERSION}&namespace=${encodeURIComponent(this.#config.namespace)}&runner_key=${encodeURIComponent(this.#config.runnerKey)}`; + + // Ensure the endpoint ends with /runners/connect + const baseUrl = wsEndpoint.endsWith("/") + ? wsEndpoint.slice(0, -1) + : wsEndpoint; + return `${baseUrl}/runners/connect?protocol_version=${PROTOCOL_VERSION}&namespace=${encodeURIComponent(this.#config.namespace)}&runner_key=${encodeURIComponent(this.#config.runnerKey)}`; } // MARK: Runner protocol async #openPegboardWebSocket() { - const protocols = ["rivet", `rivet_target.runner`]; + const protocols = ["rivet"]; if (this.config.token) protocols.push(`rivet_token.${this.config.token}`);