Skip to content

Commit e7f7c04

Browse files
committed
chore(guard): handle cors on the gateway
1 parent e86c567 commit e7f7c04

File tree

6 files changed

+130
-86
lines changed

6 files changed

+130
-86
lines changed

engine/packages/pegboard-gateway/src/lib.rs

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,15 @@ impl CustomServeTrait for PegboardGateway {
7676
// Use the actor ID from the gateway instance
7777
let actor_id = self.actor_id.to_string();
7878

79+
// Extract origin for CORS (before consuming request)
80+
// When credentials: true, we must echo back the actual origin, not "*"
81+
let origin = req
82+
.headers()
83+
.get("origin")
84+
.and_then(|v| v.to_str().ok())
85+
.unwrap_or("*")
86+
.to_string();
87+
7988
// Extract request parts
8089
let mut headers = HashableMap::new();
8190
for (name, value) in req.headers() {
@@ -87,6 +96,42 @@ impl CustomServeTrait for PegboardGateway {
8796
// Extract method and path before consuming the request
8897
let method = req.method().to_string();
8998

99+
// Handle CORS preflight OPTIONS requests at gateway level
100+
//
101+
// We need to do this in Guard because there is no way of sending an OPTIONS request to the
102+
// actor since we don't have the `x-rivet-token` header. This implementation allows
103+
// requests from anywhere and lets the actor handle CORS manually in `onBeforeConnect`.
104+
// This had the added benefit of also applying to WebSockets.
105+
if req.method() == hyper::Method::OPTIONS {
106+
tracing::debug!("handling OPTIONS preflight request at gateway");
107+
108+
// Extract requested headers
109+
let requested_headers = req
110+
.headers()
111+
.get("access-control-request-headers")
112+
.and_then(|v| v.to_str().ok())
113+
.unwrap_or("*");
114+
115+
let mut response = Response::builder()
116+
.status(StatusCode::NO_CONTENT)
117+
.header("access-control-allow-origin", &origin)
118+
.header("access-control-allow-credentials", "true")
119+
.header(
120+
"access-control-allow-methods",
121+
"GET, POST, PUT, DELETE, OPTIONS, PATCH",
122+
)
123+
.header("access-control-allow-headers", requested_headers)
124+
.header("access-control-expose-headers", "*")
125+
.header("access-control-max-age", "86400");
126+
127+
// Add Vary header to prevent cache poisoning when echoing origin
128+
if origin != "*" {
129+
response = response.header("vary", "Origin");
130+
}
131+
132+
return Ok(response.body(ResponseBody::Full(Full::new(Bytes::new())))?);
133+
}
134+
90135
let body_bytes = req
91136
.into_body()
92137
.collect()
@@ -164,11 +209,22 @@ impl CustomServeTrait for PegboardGateway {
164209
let mut response_builder =
165210
Response::builder().status(StatusCode::from_u16(response_start.status)?);
166211

167-
// Add headers
212+
// Add headers from actor
168213
for (key, value) in response_start.headers {
169214
response_builder = response_builder.header(key, value);
170215
}
171216

217+
// Add CORS headers to actual request
218+
response_builder = response_builder
219+
.header("access-control-allow-origin", &origin)
220+
.header("access-control-allow-credentials", "true")
221+
.header("access-control-expose-headers", "*");
222+
223+
// Add Vary header to prevent cache poisoning when echoing origin
224+
if origin != "*" {
225+
response_builder = response_builder.header("vary", "Origin");
226+
}
227+
172228
// Add body
173229
let body = response_start.body.unwrap_or_default();
174230
let response = response_builder.body(ResponseBody::Full(Full::new(Bytes::from(body))))?;

rivetkit-typescript/packages/rivetkit/src/actor/router.ts

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import { Hono, type Context as HonoContext } from "hono";
2-
import { cors } from "hono/cors";
32
import invariant from "invariant";
43
import { EncodingSchema } from "@/actor/protocol/serde";
54
import {
@@ -320,22 +319,18 @@ export function createActorRouter(
320319
new Hono<
321320
ActorInspectorRouterEnv & { Bindings: ActorRouterBindings }
322321
>()
323-
.use(
324-
cors(runConfig.inspector.cors),
325-
secureInspector(runConfig),
326-
async (c, next) => {
327-
const inspector = (
328-
await actorDriver.loadActor(c.env.actorId)
329-
).inspector;
330-
invariant(
331-
inspector,
332-
"inspector not supported on this platform",
333-
);
334-
335-
c.set("inspector", inspector);
336-
return next();
337-
},
338-
)
322+
.use(secureInspector(runConfig), async (c, next) => {
323+
const inspector = (
324+
await actorDriver.loadActor(c.env.actorId)
325+
).inspector;
326+
invariant(
327+
inspector,
328+
"inspector not supported on this platform",
329+
);
330+
331+
c.set("inspector", inspector);
332+
return next();
333+
})
339334
.route("/", createActorInspectorRouter()),
340335
);
341336
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import type { MiddlewareHandler } from "hono";
2+
3+
/**
4+
* Simple CORS middleware that matches the gateway behavior.
5+
*
6+
* - Echoes back the Origin header from the request
7+
* - Echoes back the Access-Control-Request-Headers from preflight
8+
* - Supports credentials
9+
* - Allows common HTTP methods
10+
* - Caches preflight for 24 hours
11+
* - Adds Vary header to prevent cache poisoning
12+
*/
13+
export const cors = (): MiddlewareHandler => {
14+
return async (c, next) => {
15+
// Extract origin from request
16+
const origin = c.req.header("origin") || "*";
17+
18+
// Handle preflight OPTIONS request
19+
if (c.req.method === "OPTIONS") {
20+
const requestHeaders =
21+
c.req.header("access-control-request-headers") || "*";
22+
23+
c.header("access-control-allow-origin", origin);
24+
c.header("access-control-allow-credentials", "true");
25+
c.header(
26+
"access-control-allow-methods",
27+
"GET, POST, PUT, DELETE, OPTIONS, PATCH",
28+
);
29+
c.header("access-control-allow-headers", requestHeaders);
30+
c.header("access-control-expose-headers", "*");
31+
c.header("access-control-max-age", "86400");
32+
33+
// Add Vary header to prevent cache poisoning when echoing origin
34+
if (origin !== "*") {
35+
c.header("vary", "Origin");
36+
}
37+
38+
// Remove content headers from preflight response
39+
c.res.headers.delete("content-length");
40+
c.res.headers.delete("content-type");
41+
42+
return c.body(null, 204);
43+
}
44+
45+
await next();
46+
47+
// Add CORS headers to actual request
48+
c.header("access-control-allow-origin", origin);
49+
c.header("access-control-allow-credentials", "true");
50+
c.header("access-control-expose-headers", "*");
51+
52+
// Add Vary header to prevent cache poisoning when echoing origin
53+
if (origin !== "*") {
54+
c.header("vary", "Origin");
55+
}
56+
};
57+
};
Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
1-
import type { cors } from "hono/cors";
21
import { z } from "zod";
3-
import { HEADER_ACTOR_QUERY } from "@/driver-helpers/mod";
42
import { getEnvUniversal } from "@/utils";
53

6-
type CorsOptions = NonNullable<Parameters<typeof cors>[0]>;
7-
84
const defaultTokenFn = () => {
95
const envToken = getEnvUniversal("RIVETKIT_INSPECTOR_TOKEN");
106

@@ -22,41 +18,6 @@ const defaultEnabled = () => {
2218
);
2319
};
2420

25-
const defaultInspectorOrigins = [
26-
"http://localhost:43708",
27-
"http://localhost:43709",
28-
"https://studio.rivet.gg",
29-
"https://inspect.rivet.dev",
30-
"https://dashboard.rivet.dev",
31-
"https://dashboard.staging.rivet.dev",
32-
];
33-
34-
const defaultCors: CorsOptions = {
35-
origin: (origin) => {
36-
if (
37-
defaultInspectorOrigins.includes(origin) ||
38-
(origin.startsWith("https://") &&
39-
origin.endsWith("rivet-dev.vercel.app"))
40-
) {
41-
return origin;
42-
} else {
43-
return null;
44-
}
45-
},
46-
allowMethods: ["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
47-
allowHeaders: [
48-
"Authorization",
49-
"Content-Type",
50-
"User-Agent",
51-
"baggage",
52-
"sentry-trace",
53-
"x-rivet-actor",
54-
"x-rivet-target",
55-
],
56-
maxAge: 3600,
57-
credentials: true,
58-
};
59-
6021
export const InspectorConfigSchema = z
6122
.object({
6223
enabled: z
@@ -69,11 +30,6 @@ export const InspectorConfigSchema = z
6930
)
7031
.optional()
7132
.default(defaultEnabled),
72-
/** CORS configuration for the router. Uses Hono's CORS middleware options. */
73-
cors: z
74-
.custom<CorsOptions>()
75-
.optional()
76-
.default(() => defaultCors),
7733

7834
/**
7935
* Token used to access the Inspector.
@@ -95,6 +51,5 @@ export const InspectorConfigSchema = z
9551
.default(() => ({
9652
enabled: defaultEnabled(),
9753
token: defaultTokenFn,
98-
cors: defaultCors,
9954
}));
10055
export type InspectorConfig = z.infer<typeof InspectorConfigSchema>;

rivetkit-typescript/packages/rivetkit/src/manager/router.ts

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import {
66
type MiddlewareHandler,
77
type Next,
88
} from "hono";
9-
import { cors as corsMiddleware } from "hono/cors";
109
import { createMiddleware } from "hono/factory";
1110
import { streamSSE } from "hono/streaming";
1211
import invariant from "invariant";
@@ -23,45 +22,41 @@ import {
2322
WS_PROTOCOL_PATH,
2423
WS_PROTOCOL_TRANSPORT,
2524
} from "@/common/actor-router-consts";
25+
import { cors } from "@/common/cors";
2626
import {
2727
handleHealthRequest,
2828
handleMetadataRequest,
2929
handleRouteError,
3030
handleRouteNotFound,
3131
loggerMiddleware,
32-
type MetadataResponse,
3332
} from "@/common/router";
3433
import {
3534
assertUnreachable,
3635
deconstructError,
3736
noopNext,
3837
stringifyError,
3938
} from "@/common/utils";
40-
import { type ActorDriver, HEADER_ACTOR_ID } from "@/driver-helpers/mod";
39+
import { HEADER_ACTOR_ID } from "@/driver-helpers/mod";
4140
import type {
4241
TestInlineDriverCallRequest,
4342
TestInlineDriverCallResponse,
4443
} from "@/driver-test-suite/test-inline-client-driver";
4544
import { createManagerInspectorRouter } from "@/inspector/manager";
4645
import { isInspectorEnabled, secureInspector } from "@/inspector/utils";
4746
import {
48-
type ActorsCreateRequest,
4947
ActorsCreateRequestSchema,
5048
type ActorsCreateResponse,
5149
ActorsCreateResponseSchema,
52-
type ActorsGetOrCreateRequest,
5350
ActorsGetOrCreateRequestSchema,
5451
type ActorsGetOrCreateResponse,
5552
ActorsGetOrCreateResponseSchema,
5653
type ActorsListResponse,
5754
ActorsListResponseSchema,
5855
type Actor as ApiActor,
5956
} from "@/manager-api/actors";
60-
import { RivetIdSchema } from "@/manager-api/common";
6157
import type { AnyClient } from "@/mod";
6258
import type { RegistryConfig } from "@/registry/config";
6359
import type { DriverConfig, RunnerConfig } from "@/registry/run-config";
64-
import { VERSION } from "@/utils";
6560
import type { ActorOutput, ManagerDriver } from "./driver";
6661
import { actorGateway, createTestWebSocketProxy } from "./gateway";
6762
import { logger } from "./log";
@@ -97,7 +92,7 @@ export function createManagerRouter(
9792
runConfig.basePath,
9893
);
9994

100-
router.use("*", loggerMiddleware(logger()));
95+
router.use("*", loggerMiddleware(logger()), cors());
10196

10297
// HACK: Add Sec-WebSocket-Protocol header to fix KIT-339
10398
//
@@ -148,9 +143,6 @@ function addServerlessRoutes(
148143
client: AnyClient,
149144
router: OpenAPIHono,
150145
) {
151-
// Apply CORS
152-
if (runConfig.cors) router.use("*", corsMiddleware(runConfig.cors));
153-
154146
// GET /
155147
router.get("/", (c) => {
156148
return c.text(
@@ -223,16 +215,14 @@ function addManagerRoutes(
223215
managerDriver: ManagerDriver,
224216
router: OpenAPIHono,
225217
) {
226-
// Serve inspector BEFORE the rest of the routes, since this has a special
227-
// CORS config that should take precedence for the `/inspector` path
218+
// Inspector
228219
if (isInspectorEnabled(runConfig, "manager")) {
229220
if (!managerDriver.inspector) {
230221
throw new Unsupported("inspector");
231222
}
232223
router.route(
233224
"/inspect",
234225
new Hono<{ Variables: { inspector: any } }>()
235-
.use(corsMiddleware(runConfig.inspector.cors))
236226
.use(secureInspector(runConfig))
237227
.use((c, next) => {
238228
c.set("inspector", managerDriver.inspector!);
@@ -242,9 +232,6 @@ function addManagerRoutes(
242232
);
243233
}
244234

245-
// Apply CORS
246-
if (runConfig.cors) router.use("*", corsMiddleware(runConfig.cors));
247-
248235
// Actor gateway
249236
router.use("*", actorGateway.bind(undefined, runConfig, managerDriver));
250237

rivetkit-typescript/packages/rivetkit/src/registry/run-config.ts

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import type { cors } from "hono/cors";
21
import type { Logger } from "pino";
32
import { z } from "zod";
43
import type { ActorDriverBuilder } from "@/actor/driver";
@@ -9,8 +8,6 @@ import type { ManagerDriverBuilder } from "@/manager/driver";
98
import type { GetUpgradeWebSocket } from "@/utils";
109
import { getEnvUniversal } from "@/utils";
1110

12-
type CorsOptions = NonNullable<Parameters<typeof cors>[0]>;
13-
1411
export const DriverConfigSchema = z.object({
1512
/** Machine-readable name to identify this driver by. */
1613
name: z.string(),
@@ -25,9 +22,6 @@ export const RunnerConfigSchema = z
2522
.object({
2623
driver: DriverConfigSchema.optional(),
2724

28-
/** CORS configuration for the router. Uses Hono's CORS middleware options. */
29-
cors: z.custom<CorsOptions>().optional(),
30-
3125
/** @experimental */
3226
maxIncomingMessageSize: z.number().optional().default(65_536),
3327

0 commit comments

Comments
 (0)