Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion engine/packages/pegboard-gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ impl CustomServeTrait for PegboardGateway {
// Use the actor ID from the gateway instance
let actor_id = self.actor_id.to_string();

// Extract origin for CORS (before consuming request)
// When credentials: true, we must echo back the actual origin, not "*"
let origin = req
.headers()
.get("origin")
.and_then(|v| v.to_str().ok())
.unwrap_or("*")
.to_string();

// Extract request parts
let mut headers = HashableMap::new();
for (name, value) in req.headers() {
Expand All @@ -87,6 +96,42 @@ impl CustomServeTrait for PegboardGateway {
// Extract method and path before consuming the request
let method = req.method().to_string();

// Handle CORS preflight OPTIONS requests at gateway level
//
// We need to do this in Guard because there is no way of sending an OPTIONS request to the
// actor since we don't have the `x-rivet-token` header. This implementation allows
// requests from anywhere and lets the actor handle CORS manually in `onBeforeConnect`.
// This had the added benefit of also applying to WebSockets.
if req.method() == hyper::Method::OPTIONS {
tracing::debug!("handling OPTIONS preflight request at gateway");

// Extract requested headers
let requested_headers = req
.headers()
.get("access-control-request-headers")
.and_then(|v| v.to_str().ok())
.unwrap_or("*");

let mut response = Response::builder()
.status(StatusCode::NO_CONTENT)
.header("access-control-allow-origin", &origin)
.header("access-control-allow-credentials", "true")
.header(
"access-control-allow-methods",
"GET, POST, PUT, DELETE, OPTIONS, PATCH",
)
.header("access-control-allow-headers", requested_headers)
.header("access-control-expose-headers", "*")
.header("access-control-max-age", "86400");

// Add Vary header to prevent cache poisoning when echoing origin
if origin != "*" {
response = response.header("vary", "Origin");
}

return Ok(response.body(ResponseBody::Full(Full::new(Bytes::new())))?);
}

let body_bytes = req
.into_body()
.collect()
Expand Down Expand Up @@ -164,11 +209,22 @@ impl CustomServeTrait for PegboardGateway {
let mut response_builder =
Response::builder().status(StatusCode::from_u16(response_start.status)?);

// Add headers
// Add headers from actor
for (key, value) in response_start.headers {
response_builder = response_builder.header(key, value);
}

// Add CORS headers to actual request
response_builder = response_builder
.header("access-control-allow-origin", &origin)
.header("access-control-allow-credentials", "true")
.header("access-control-expose-headers", "*");

// Add Vary header to prevent cache poisoning when echoing origin
if origin != "*" {
response_builder = response_builder.header("vary", "Origin");
}

// Add body
let body = response_start.body.unwrap_or_default();
let response = response_builder.body(ResponseBody::Full(Full::new(Bytes::from(body))))?;
Expand Down
29 changes: 12 additions & 17 deletions rivetkit-typescript/packages/rivetkit/src/actor/router.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import { Hono, type Context as HonoContext } from "hono";
import { cors } from "hono/cors";
import invariant from "invariant";
import { EncodingSchema } from "@/actor/protocol/serde";
import {
Expand Down Expand Up @@ -320,22 +319,18 @@ export function createActorRouter(
new Hono<
ActorInspectorRouterEnv & { Bindings: ActorRouterBindings }
>()
.use(
cors(runConfig.inspector.cors),
secureInspector(runConfig),
async (c, next) => {
const inspector = (
await actorDriver.loadActor(c.env.actorId)
).inspector;
invariant(
inspector,
"inspector not supported on this platform",
);

c.set("inspector", inspector);
return next();
},
)
.use(secureInspector(runConfig), async (c, next) => {
const inspector = (
await actorDriver.loadActor(c.env.actorId)
).inspector;
invariant(
inspector,
"inspector not supported on this platform",
);

c.set("inspector", inspector);
return next();
})
.route("/", createActorInspectorRouter()),
);
}
Expand Down
57 changes: 57 additions & 0 deletions rivetkit-typescript/packages/rivetkit/src/common/cors.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import type { MiddlewareHandler } from "hono";

/**
* Simple CORS middleware that matches the gateway behavior.
*
* - Echoes back the Origin header from the request
* - Echoes back the Access-Control-Request-Headers from preflight
* - Supports credentials
* - Allows common HTTP methods
* - Caches preflight for 24 hours
* - Adds Vary header to prevent cache poisoning
*/
export const cors = (): MiddlewareHandler => {
return async (c, next) => {
// Extract origin from request
const origin = c.req.header("origin") || "*";

// Handle preflight OPTIONS request
if (c.req.method === "OPTIONS") {
const requestHeaders =
c.req.header("access-control-request-headers") || "*";

c.header("access-control-allow-origin", origin);
c.header("access-control-allow-credentials", "true");
c.header(
"access-control-allow-methods",
"GET, POST, PUT, DELETE, OPTIONS, PATCH",
);
c.header("access-control-allow-headers", requestHeaders);
c.header("access-control-expose-headers", "*");
c.header("access-control-max-age", "86400");

// Add Vary header to prevent cache poisoning when echoing origin
if (origin !== "*") {
c.header("vary", "Origin");
}

// Remove content headers from preflight response
c.res.headers.delete("content-length");
c.res.headers.delete("content-type");

return c.body(null, 204);
}

await next();

// Add CORS headers to actual request
c.header("access-control-allow-origin", origin);
c.header("access-control-allow-credentials", "true");
c.header("access-control-expose-headers", "*");

// Add Vary header to prevent cache poisoning when echoing origin
if (origin !== "*") {
c.header("vary", "Origin");
}
};
};
45 changes: 0 additions & 45 deletions rivetkit-typescript/packages/rivetkit/src/inspector/config.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import type { cors } from "hono/cors";
import { z } from "zod";
import { HEADER_ACTOR_QUERY } from "@/driver-helpers/mod";
import { getEnvUniversal } from "@/utils";

type CorsOptions = NonNullable<Parameters<typeof cors>[0]>;

const defaultTokenFn = () => {
const envToken = getEnvUniversal("RIVETKIT_INSPECTOR_TOKEN");

Expand All @@ -22,41 +18,6 @@ const defaultEnabled = () => {
);
};

const defaultInspectorOrigins = [
"http://localhost:43708",
"http://localhost:43709",
"https://studio.rivet.gg",
"https://inspect.rivet.dev",
"https://dashboard.rivet.dev",
"https://dashboard.staging.rivet.dev",
];

const defaultCors: CorsOptions = {
origin: (origin) => {
if (
defaultInspectorOrigins.includes(origin) ||
(origin.startsWith("https://") &&
origin.endsWith("rivet-dev.vercel.app"))
) {
return origin;
} else {
return null;
}
},
allowMethods: ["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allowHeaders: [
"Authorization",
"Content-Type",
"User-Agent",
"baggage",
"sentry-trace",
"x-rivet-actor",
"x-rivet-target",
],
maxAge: 3600,
credentials: true,
};

export const InspectorConfigSchema = z
.object({
enabled: z
Expand All @@ -69,11 +30,6 @@ export const InspectorConfigSchema = z
)
.optional()
.default(defaultEnabled),
/** CORS configuration for the router. Uses Hono's CORS middleware options. */
cors: z
.custom<CorsOptions>()
.optional()
.default(() => defaultCors),

/**
* Token used to access the Inspector.
Expand All @@ -95,6 +51,5 @@ export const InspectorConfigSchema = z
.default(() => ({
enabled: defaultEnabled(),
token: defaultTokenFn,
cors: defaultCors,
}));
export type InspectorConfig = z.infer<typeof InspectorConfigSchema>;
21 changes: 4 additions & 17 deletions rivetkit-typescript/packages/rivetkit/src/manager/router.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import {
type MiddlewareHandler,
type Next,
} from "hono";
import { cors as corsMiddleware } from "hono/cors";
import { createMiddleware } from "hono/factory";
import { streamSSE } from "hono/streaming";
import invariant from "invariant";
Expand All @@ -23,45 +22,41 @@ import {
WS_PROTOCOL_PATH,
WS_PROTOCOL_TRANSPORT,
} from "@/common/actor-router-consts";
import { cors } from "@/common/cors";
import {
handleHealthRequest,
handleMetadataRequest,
handleRouteError,
handleRouteNotFound,
loggerMiddleware,
type MetadataResponse,
} from "@/common/router";
import {
assertUnreachable,
deconstructError,
noopNext,
stringifyError,
} from "@/common/utils";
import { type ActorDriver, HEADER_ACTOR_ID } from "@/driver-helpers/mod";
import { HEADER_ACTOR_ID } from "@/driver-helpers/mod";
import type {
TestInlineDriverCallRequest,
TestInlineDriverCallResponse,
} from "@/driver-test-suite/test-inline-client-driver";
import { createManagerInspectorRouter } from "@/inspector/manager";
import { isInspectorEnabled, secureInspector } from "@/inspector/utils";
import {
type ActorsCreateRequest,
ActorsCreateRequestSchema,
type ActorsCreateResponse,
ActorsCreateResponseSchema,
type ActorsGetOrCreateRequest,
ActorsGetOrCreateRequestSchema,
type ActorsGetOrCreateResponse,
ActorsGetOrCreateResponseSchema,
type ActorsListResponse,
ActorsListResponseSchema,
type Actor as ApiActor,
} from "@/manager-api/actors";
import { RivetIdSchema } from "@/manager-api/common";
import type { AnyClient } from "@/mod";
import type { RegistryConfig } from "@/registry/config";
import type { DriverConfig, RunnerConfig } from "@/registry/run-config";
import { VERSION } from "@/utils";
import type { ActorOutput, ManagerDriver } from "./driver";
import { actorGateway, createTestWebSocketProxy } from "./gateway";
import { logger } from "./log";
Expand Down Expand Up @@ -97,7 +92,7 @@ export function createManagerRouter(
runConfig.basePath,
);

router.use("*", loggerMiddleware(logger()));
router.use("*", loggerMiddleware(logger()), cors());

// HACK: Add Sec-WebSocket-Protocol header to fix KIT-339
//
Expand Down Expand Up @@ -148,9 +143,6 @@ function addServerlessRoutes(
client: AnyClient,
router: OpenAPIHono,
) {
// Apply CORS
if (runConfig.cors) router.use("*", corsMiddleware(runConfig.cors));

// GET /
router.get("/", (c) => {
return c.text(
Expand Down Expand Up @@ -223,16 +215,14 @@ function addManagerRoutes(
managerDriver: ManagerDriver,
router: OpenAPIHono,
) {
// Serve inspector BEFORE the rest of the routes, since this has a special
// CORS config that should take precedence for the `/inspector` path
// Inspector
if (isInspectorEnabled(runConfig, "manager")) {
if (!managerDriver.inspector) {
throw new Unsupported("inspector");
}
router.route(
"/inspect",
new Hono<{ Variables: { inspector: any } }>()
.use(corsMiddleware(runConfig.inspector.cors))
.use(secureInspector(runConfig))
.use((c, next) => {
c.set("inspector", managerDriver.inspector!);
Expand All @@ -242,9 +232,6 @@ function addManagerRoutes(
);
}

// Apply CORS
if (runConfig.cors) router.use("*", corsMiddleware(runConfig.cors));

// Actor gateway
router.use("*", actorGateway.bind(undefined, runConfig, managerDriver));

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import type { cors } from "hono/cors";
import type { Logger } from "pino";
import { z } from "zod";
import type { ActorDriverBuilder } from "@/actor/driver";
Expand All @@ -9,8 +8,6 @@ import type { ManagerDriverBuilder } from "@/manager/driver";
import type { GetUpgradeWebSocket } from "@/utils";
import { getEnvUniversal } from "@/utils";

type CorsOptions = NonNullable<Parameters<typeof cors>[0]>;

export const DriverConfigSchema = z.object({
/** Machine-readable name to identify this driver by. */
name: z.string(),
Expand All @@ -25,9 +22,6 @@ export const RunnerConfigSchema = z
.object({
driver: DriverConfigSchema.optional(),

/** CORS configuration for the router. Uses Hono's CORS middleware options. */
cors: z.custom<CorsOptions>().optional(),

/** @experimental */
maxIncomingMessageSize: z.number().optional().default(65_536),

Expand Down
Loading