From c9797c4af78b04148fcdb5132e626503a550d302 Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Tue, 4 Nov 2025 18:26:11 +0000 Subject: [PATCH] chore(guard): handle cors on the gateway --- engine/packages/pegboard-gateway/src/lib.rs | 58 ++++++++++++++++++- .../packages/rivetkit/src/actor/router.ts | 29 ++++------ .../packages/rivetkit/src/common/cors.ts | 57 ++++++++++++++++++ .../packages/rivetkit/src/inspector/config.ts | 45 -------------- .../packages/rivetkit/src/manager/router.ts | 21 ++----- .../rivetkit/src/registry/run-config.ts | 6 -- 6 files changed, 130 insertions(+), 86 deletions(-) create mode 100644 rivetkit-typescript/packages/rivetkit/src/common/cors.ts diff --git a/engine/packages/pegboard-gateway/src/lib.rs b/engine/packages/pegboard-gateway/src/lib.rs index 5bbbd978b4..7542fdfe96 100644 --- a/engine/packages/pegboard-gateway/src/lib.rs +++ b/engine/packages/pegboard-gateway/src/lib.rs @@ -76,6 +76,15 @@ impl CustomServeTrait for PegboardGateway { // Use the actor ID from the gateway instance let actor_id = self.actor_id.to_string(); + // Extract origin for CORS (before consuming request) + // When credentials: true, we must echo back the actual origin, not "*" + let origin = req + .headers() + .get("origin") + .and_then(|v| v.to_str().ok()) + .unwrap_or("*") + .to_string(); + // Extract request parts let mut headers = HashableMap::new(); for (name, value) in req.headers() { @@ -87,6 +96,42 @@ impl CustomServeTrait for PegboardGateway { // Extract method and path before consuming the request let method = req.method().to_string(); + // Handle CORS preflight OPTIONS requests at gateway level + // + // We need to do this in Guard because there is no way of sending an OPTIONS request to the + // actor since we don't have the `x-rivet-token` header. This implementation allows + // requests from anywhere and lets the actor handle CORS manually in `onBeforeConnect`. + // This had the added benefit of also applying to WebSockets. + if req.method() == hyper::Method::OPTIONS { + tracing::debug!("handling OPTIONS preflight request at gateway"); + + // Extract requested headers + let requested_headers = req + .headers() + .get("access-control-request-headers") + .and_then(|v| v.to_str().ok()) + .unwrap_or("*"); + + let mut response = Response::builder() + .status(StatusCode::NO_CONTENT) + .header("access-control-allow-origin", &origin) + .header("access-control-allow-credentials", "true") + .header( + "access-control-allow-methods", + "GET, POST, PUT, DELETE, OPTIONS, PATCH", + ) + .header("access-control-allow-headers", requested_headers) + .header("access-control-expose-headers", "*") + .header("access-control-max-age", "86400"); + + // Add Vary header to prevent cache poisoning when echoing origin + if origin != "*" { + response = response.header("vary", "Origin"); + } + + return Ok(response.body(ResponseBody::Full(Full::new(Bytes::new())))?); + } + let body_bytes = req .into_body() .collect() @@ -164,11 +209,22 @@ impl CustomServeTrait for PegboardGateway { let mut response_builder = Response::builder().status(StatusCode::from_u16(response_start.status)?); - // Add headers + // Add headers from actor for (key, value) in response_start.headers { response_builder = response_builder.header(key, value); } + // Add CORS headers to actual request + response_builder = response_builder + .header("access-control-allow-origin", &origin) + .header("access-control-allow-credentials", "true") + .header("access-control-expose-headers", "*"); + + // Add Vary header to prevent cache poisoning when echoing origin + if origin != "*" { + response_builder = response_builder.header("vary", "Origin"); + } + // Add body let body = response_start.body.unwrap_or_default(); let response = response_builder.body(ResponseBody::Full(Full::new(Bytes::from(body))))?; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/router.ts b/rivetkit-typescript/packages/rivetkit/src/actor/router.ts index 61b00b987b..1a3c863e6e 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/router.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/router.ts @@ -1,5 +1,4 @@ import { Hono, type Context as HonoContext } from "hono"; -import { cors } from "hono/cors"; import invariant from "invariant"; import { EncodingSchema } from "@/actor/protocol/serde"; import { @@ -320,22 +319,18 @@ export function createActorRouter( new Hono< ActorInspectorRouterEnv & { Bindings: ActorRouterBindings } >() - .use( - cors(runConfig.inspector.cors), - secureInspector(runConfig), - async (c, next) => { - const inspector = ( - await actorDriver.loadActor(c.env.actorId) - ).inspector; - invariant( - inspector, - "inspector not supported on this platform", - ); - - c.set("inspector", inspector); - return next(); - }, - ) + .use(secureInspector(runConfig), async (c, next) => { + const inspector = ( + await actorDriver.loadActor(c.env.actorId) + ).inspector; + invariant( + inspector, + "inspector not supported on this platform", + ); + + c.set("inspector", inspector); + return next(); + }) .route("/", createActorInspectorRouter()), ); } diff --git a/rivetkit-typescript/packages/rivetkit/src/common/cors.ts b/rivetkit-typescript/packages/rivetkit/src/common/cors.ts new file mode 100644 index 0000000000..3b63012fcc --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/common/cors.ts @@ -0,0 +1,57 @@ +import type { MiddlewareHandler } from "hono"; + +/** + * Simple CORS middleware that matches the gateway behavior. + * + * - Echoes back the Origin header from the request + * - Echoes back the Access-Control-Request-Headers from preflight + * - Supports credentials + * - Allows common HTTP methods + * - Caches preflight for 24 hours + * - Adds Vary header to prevent cache poisoning + */ +export const cors = (): MiddlewareHandler => { + return async (c, next) => { + // Extract origin from request + const origin = c.req.header("origin") || "*"; + + // Handle preflight OPTIONS request + if (c.req.method === "OPTIONS") { + const requestHeaders = + c.req.header("access-control-request-headers") || "*"; + + c.header("access-control-allow-origin", origin); + c.header("access-control-allow-credentials", "true"); + c.header( + "access-control-allow-methods", + "GET, POST, PUT, DELETE, OPTIONS, PATCH", + ); + c.header("access-control-allow-headers", requestHeaders); + c.header("access-control-expose-headers", "*"); + c.header("access-control-max-age", "86400"); + + // Add Vary header to prevent cache poisoning when echoing origin + if (origin !== "*") { + c.header("vary", "Origin"); + } + + // Remove content headers from preflight response + c.res.headers.delete("content-length"); + c.res.headers.delete("content-type"); + + return c.body(null, 204); + } + + await next(); + + // Add CORS headers to actual request + c.header("access-control-allow-origin", origin); + c.header("access-control-allow-credentials", "true"); + c.header("access-control-expose-headers", "*"); + + // Add Vary header to prevent cache poisoning when echoing origin + if (origin !== "*") { + c.header("vary", "Origin"); + } + }; +}; diff --git a/rivetkit-typescript/packages/rivetkit/src/inspector/config.ts b/rivetkit-typescript/packages/rivetkit/src/inspector/config.ts index e122eb7e65..69688f3c8e 100644 --- a/rivetkit-typescript/packages/rivetkit/src/inspector/config.ts +++ b/rivetkit-typescript/packages/rivetkit/src/inspector/config.ts @@ -1,10 +1,6 @@ -import type { cors } from "hono/cors"; import { z } from "zod"; -import { HEADER_ACTOR_QUERY } from "@/driver-helpers/mod"; import { getEnvUniversal } from "@/utils"; -type CorsOptions = NonNullable[0]>; - const defaultTokenFn = () => { const envToken = getEnvUniversal("RIVETKIT_INSPECTOR_TOKEN"); @@ -22,41 +18,6 @@ const defaultEnabled = () => { ); }; -const defaultInspectorOrigins = [ - "http://localhost:43708", - "http://localhost:43709", - "https://studio.rivet.gg", - "https://inspect.rivet.dev", - "https://dashboard.rivet.dev", - "https://dashboard.staging.rivet.dev", -]; - -const defaultCors: CorsOptions = { - origin: (origin) => { - if ( - defaultInspectorOrigins.includes(origin) || - (origin.startsWith("https://") && - origin.endsWith("rivet-dev.vercel.app")) - ) { - return origin; - } else { - return null; - } - }, - allowMethods: ["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], - allowHeaders: [ - "Authorization", - "Content-Type", - "User-Agent", - "baggage", - "sentry-trace", - "x-rivet-actor", - "x-rivet-target", - ], - maxAge: 3600, - credentials: true, -}; - export const InspectorConfigSchema = z .object({ enabled: z @@ -69,11 +30,6 @@ export const InspectorConfigSchema = z ) .optional() .default(defaultEnabled), - /** CORS configuration for the router. Uses Hono's CORS middleware options. */ - cors: z - .custom() - .optional() - .default(() => defaultCors), /** * Token used to access the Inspector. @@ -95,6 +51,5 @@ export const InspectorConfigSchema = z .default(() => ({ enabled: defaultEnabled(), token: defaultTokenFn, - cors: defaultCors, })); export type InspectorConfig = z.infer; diff --git a/rivetkit-typescript/packages/rivetkit/src/manager/router.ts b/rivetkit-typescript/packages/rivetkit/src/manager/router.ts index 3954219b1e..ca6e6de1f1 100644 --- a/rivetkit-typescript/packages/rivetkit/src/manager/router.ts +++ b/rivetkit-typescript/packages/rivetkit/src/manager/router.ts @@ -6,7 +6,6 @@ import { type MiddlewareHandler, type Next, } from "hono"; -import { cors as corsMiddleware } from "hono/cors"; import { createMiddleware } from "hono/factory"; import { streamSSE } from "hono/streaming"; import invariant from "invariant"; @@ -23,13 +22,13 @@ import { WS_PROTOCOL_PATH, WS_PROTOCOL_TRANSPORT, } from "@/common/actor-router-consts"; +import { cors } from "@/common/cors"; import { handleHealthRequest, handleMetadataRequest, handleRouteError, handleRouteNotFound, loggerMiddleware, - type MetadataResponse, } from "@/common/router"; import { assertUnreachable, @@ -37,7 +36,7 @@ import { noopNext, stringifyError, } from "@/common/utils"; -import { type ActorDriver, HEADER_ACTOR_ID } from "@/driver-helpers/mod"; +import { HEADER_ACTOR_ID } from "@/driver-helpers/mod"; import type { TestInlineDriverCallRequest, TestInlineDriverCallResponse, @@ -45,11 +44,9 @@ import type { import { createManagerInspectorRouter } from "@/inspector/manager"; import { isInspectorEnabled, secureInspector } from "@/inspector/utils"; import { - type ActorsCreateRequest, ActorsCreateRequestSchema, type ActorsCreateResponse, ActorsCreateResponseSchema, - type ActorsGetOrCreateRequest, ActorsGetOrCreateRequestSchema, type ActorsGetOrCreateResponse, ActorsGetOrCreateResponseSchema, @@ -57,11 +54,9 @@ import { ActorsListResponseSchema, type Actor as ApiActor, } from "@/manager-api/actors"; -import { RivetIdSchema } from "@/manager-api/common"; import type { AnyClient } from "@/mod"; import type { RegistryConfig } from "@/registry/config"; import type { DriverConfig, RunnerConfig } from "@/registry/run-config"; -import { VERSION } from "@/utils"; import type { ActorOutput, ManagerDriver } from "./driver"; import { actorGateway, createTestWebSocketProxy } from "./gateway"; import { logger } from "./log"; @@ -97,7 +92,7 @@ export function createManagerRouter( runConfig.basePath, ); - router.use("*", loggerMiddleware(logger())); + router.use("*", loggerMiddleware(logger()), cors()); // HACK: Add Sec-WebSocket-Protocol header to fix KIT-339 // @@ -148,9 +143,6 @@ function addServerlessRoutes( client: AnyClient, router: OpenAPIHono, ) { - // Apply CORS - if (runConfig.cors) router.use("*", corsMiddleware(runConfig.cors)); - // GET / router.get("/", (c) => { return c.text( @@ -223,8 +215,7 @@ function addManagerRoutes( managerDriver: ManagerDriver, router: OpenAPIHono, ) { - // Serve inspector BEFORE the rest of the routes, since this has a special - // CORS config that should take precedence for the `/inspector` path + // Inspector if (isInspectorEnabled(runConfig, "manager")) { if (!managerDriver.inspector) { throw new Unsupported("inspector"); @@ -232,7 +223,6 @@ function addManagerRoutes( router.route( "/inspect", new Hono<{ Variables: { inspector: any } }>() - .use(corsMiddleware(runConfig.inspector.cors)) .use(secureInspector(runConfig)) .use((c, next) => { c.set("inspector", managerDriver.inspector!); @@ -242,9 +232,6 @@ function addManagerRoutes( ); } - // Apply CORS - if (runConfig.cors) router.use("*", corsMiddleware(runConfig.cors)); - // Actor gateway router.use("*", actorGateway.bind(undefined, runConfig, managerDriver)); diff --git a/rivetkit-typescript/packages/rivetkit/src/registry/run-config.ts b/rivetkit-typescript/packages/rivetkit/src/registry/run-config.ts index 866d900baf..9b42f3c294 100644 --- a/rivetkit-typescript/packages/rivetkit/src/registry/run-config.ts +++ b/rivetkit-typescript/packages/rivetkit/src/registry/run-config.ts @@ -1,4 +1,3 @@ -import type { cors } from "hono/cors"; import type { Logger } from "pino"; import { z } from "zod"; import type { ActorDriverBuilder } from "@/actor/driver"; @@ -9,8 +8,6 @@ import type { ManagerDriverBuilder } from "@/manager/driver"; import type { GetUpgradeWebSocket } from "@/utils"; import { getEnvUniversal } from "@/utils"; -type CorsOptions = NonNullable[0]>; - export const DriverConfigSchema = z.object({ /** Machine-readable name to identify this driver by. */ name: z.string(), @@ -25,9 +22,6 @@ export const RunnerConfigSchema = z .object({ driver: DriverConfigSchema.optional(), - /** CORS configuration for the router. Uses Hono's CORS middleware options. */ - cors: z.custom().optional(), - /** @experimental */ maxIncomingMessageSize: z.number().optional().default(65_536),