diff --git a/engine/sdks/typescript/runner/src/mod.ts b/engine/sdks/typescript/runner/src/mod.ts index 4603cc35f7..64c4e3c309 100644 --- a/engine/sdks/typescript/runner/src/mod.ts +++ b/engine/sdks/typescript/runner/src/mod.ts @@ -62,7 +62,7 @@ export interface RunnerConfig { config: ActorConfig, ) => Promise; onActorStop: (actorId: string, generation: number) => Promise; - getActorHibernationConfig: (actorId: string, requestId: ArrayBuffer) => HibernationConfig; + getActorHibernationConfig: (actorId: string, requestId: ArrayBuffer, request: Request) => HibernationConfig; noAutoShutdown?: boolean; } diff --git a/engine/sdks/typescript/runner/src/tunnel.ts b/engine/sdks/typescript/runner/src/tunnel.ts index 9882341bc1..caa78648a8 100644 --- a/engine/sdks/typescript/runner/src/tunnel.ts +++ b/engine/sdks/typescript/runner/src/tunnel.ts @@ -541,21 +541,10 @@ export class Tunnel { // Store adapter this.#actorWebSockets.set(webSocketId, adapter); - // Send open confirmation - let hibernationConfig = this.#runner.config.getActorHibernationConfig(actor.actorId, requestId); - this.#sendMessage(requestId, { - tag: "ToServerWebSocketOpen", - val: { - canHibernate: hibernationConfig.enabled, - lastMsgIndex: BigInt(hibernationConfig.lastMsgIndex ?? -1), - }, - }); - - // Notify adapter that connection is open - adapter._handleOpen(requestId); - - // Create a minimal request object for the websocket handler - // Include original headers from the open message + // Convert headers to map + // + // We need to manually ensure the original Upgrade/Connection WS + // headers are present const headerInit: Record = {}; if (open.headers) { for (const [k, v] of open.headers as ReadonlyMap< @@ -565,7 +554,6 @@ export class Tunnel { headerInit[k] = v; } } - // Ensure websocket upgrade headers are present headerInit["Upgrade"] = "websocket"; headerInit["Connection"] = "Upgrade"; @@ -574,6 +562,21 @@ export class Tunnel { headers: headerInit, }); + // Send open confirmation + let hibernationConfig = this.#runner.config.getActorHibernationConfig(actor.actorId, requestId, request); + this.#sendMessage(requestId, { + tag: "ToServerWebSocketOpen", + val: { + canHibernate: hibernationConfig.enabled, + lastMsgIndex: BigInt(hibernationConfig.lastMsgIndex ?? -1), + }, + }); + + // Notify adapter that connection is open + adapter._handleOpen(requestId); + + + // Call websocket handler await websocketHandler( this.#runner, diff --git a/rivetkit-openapi/openapi.json b/rivetkit-openapi/openapi.json index 61a69f7fca..bbfd6b1713 100644 --- a/rivetkit-openapi/openapi.json +++ b/rivetkit-openapi/openapi.json @@ -1,7 +1,7 @@ { "openapi": "3.0.0", "info": { - "version": "2.0.21", + "version": "2.0.22-rc.1", "title": "RivetKit API" }, "components": { diff --git a/rivetkit-typescript/packages/rivetkit/package.json b/rivetkit-typescript/packages/rivetkit/package.json index 4beb9ce0d4..6b5c679b7a 100644 --- a/rivetkit-typescript/packages/rivetkit/package.json +++ b/rivetkit-typescript/packages/rivetkit/package.json @@ -153,7 +153,7 @@ ], "scripts": { "build": "tsup src/mod.ts src/client/mod.ts src/common/log.ts src/common/websocket.ts src/actor/errors.ts src/topologies/coordinate/mod.ts src/topologies/partition/mod.ts src/utils.ts src/driver-helpers/mod.ts src/driver-test-suite/mod.ts src/test/mod.ts src/inspector/mod.ts", - "build:schema": "./scripts/compile-bare.ts compile schemas/client-protocol/v1.bare -o dist/schemas/client-protocol/v1.ts && ./scripts/compile-bare.ts compile schemas/file-system-driver/v1.bare -o dist/schemas/file-system-driver/v1.ts && ./scripts/compile-bare.ts compile schemas/actor-persist/v1.bare -o dist/schemas/actor-persist/v1.ts", + "build:schema": "./scripts/compile-bare.ts compile schemas/client-protocol/v1.bare -o dist/schemas/client-protocol/v1.ts && ./scripts/compile-bare.ts compile schemas/file-system-driver/v1.bare -o dist/schemas/file-system-driver/v1.ts && ./scripts/compile-bare.ts compile schemas/actor-persist/v1.bare -o dist/schemas/actor-persist/v1.ts && ./scripts/compile-bare.ts compile schemas/actor-persist/v2.bare -o dist/schemas/actor-persist/v2.ts", "check-types": "tsc --noEmit", "test": "vitest run", "test:watch": "vitest", diff --git a/rivetkit-typescript/packages/rivetkit/schemas/actor-persist/v2.bare b/rivetkit-typescript/packages/rivetkit/schemas/actor-persist/v2.bare new file mode 100644 index 0000000000..89091d8976 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/schemas/actor-persist/v2.bare @@ -0,0 +1,54 @@ +# MARK: Connection +# Represents an event subscription. +type PersistedSubscription struct { + # Event name + eventName: str +} + +type PersistedConnection struct { + id: str + token: str + parameters: data + state: data + subscriptions: list + lastSeen: u64 +} + +# MARK: Schedule Event +type GenericPersistedScheduleEvent struct { + # Action name + action: str + # Arguments for the action + # + # CBOR array + args: optional +} + +type PersistedScheduleEventKind union { + GenericPersistedScheduleEvent +} + +type PersistedScheduleEvent struct { + eventId: str + timestamp: u64 + kind: PersistedScheduleEventKind +} + +# MARK: WebSocket +type PersistedHibernatableWebSocket struct { + requestId: data + lastSeenTimestamp: u64 + msgIndex: u64 +} + +# MARK: Actor +# Represents the persisted state of an actor. +type PersistedActor struct { + # Input data passed to the actor on initialization + input: optional + hasInitialized: bool + state: data + connections: list + scheduledEvents: list + hibernatableWebSocket: list +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/config.ts b/rivetkit-typescript/packages/rivetkit/src/actor/config.ts index f5152eb3bd..38a40b9262 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/config.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/config.ts @@ -70,6 +70,16 @@ export const ActorConfigSchema = z connectionLivenessInterval: z.number().positive().default(5000), noSleep: z.boolean().default(false), sleepTimeout: z.number().positive().default(30_000), + /** @experimental */ + canHibernatWebSocket: z + .union([ + z.boolean(), + z + .function() + .args(z.custom()) + .returns(z.boolean()), + ]) + .default(false), }) .strict() .default({}), diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn-drivers.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn-drivers.ts index 90daac858f..e68456a5ba 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn-drivers.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn-drivers.ts @@ -5,6 +5,7 @@ import type { AnyConn } from "@/actor/conn"; import type { AnyActorInstance } from "@/actor/instance"; import type { CachedSerializer, Encoding } from "@/actor/protocol/serde"; import { encodeDataToString } from "@/actor/protocol/serde"; +import type { HonoWebSocketAdapter } from "@/manager/hono-websocket-adapter"; import type * as protocol from "@/schemas/client-protocol/mod"; import { assertUnreachable, type promiseWithResolvers } from "@/utils"; @@ -67,6 +68,15 @@ export interface ConnDriver { conn: AnyConn, state: State, ): ConnReadyState | undefined; + + /** + * If the underlying connection can hibernate. + */ + isHibernatable( + actor: AnyActorInstance, + conn: AnyConn, + state: State, + ): boolean; } // MARK: WebSocket @@ -140,6 +150,22 @@ const WEBSOCKET_DRIVER: ConnDriver = { ): ConnReadyState | undefined => { return state.websocket.readyState; }, + + isHibernatable( + _actor: AnyActorInstance, + _conn: AnyConn, + state: ConnDriverWebSocketState, + ): boolean { + // Extract isHibernatable from the HonoWebSocketAdapter + if (state.websocket.raw) { + const raw = state.websocket.raw as HonoWebSocketAdapter; + if (typeof raw.isHibernatable === "boolean") { + return raw.isHibernatable; + } + } + + return false; + }, }; // MARK: SSE @@ -175,6 +201,10 @@ const SSE_DRIVER: ConnDriver = { return ConnReadyState.OPEN; }, + + isHibernatable(): boolean { + return false; + }, }; // MARK: HTTP @@ -187,6 +217,9 @@ const HTTP_DRIVER: ConnDriver = { // Noop // TODO: Abort the request }, + isHibernatable(): boolean { + return false; + }, }; /** List of all connection drivers. */ diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn.ts index 83bdd0a4f9..d268dbd142 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn.ts @@ -1,5 +1,6 @@ import * as cbor from "cbor-x"; import invariant from "invariant"; +import { PersistedHibernatableWebSocket } from "@/schemas/actor-persist/mod"; import type * as protocol from "@/schemas/client-protocol/mod"; import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; import { bufferToArrayBuffer } from "@/utils"; @@ -125,6 +126,25 @@ export class Conn { return this.__status; } + /** + * @experimental + * + * If the underlying connection can hibernate. + */ + public get isHibernatable(): boolean { + if (this.__driverState) { + const driverKind = getConnDriverKindFromState(this.__driverState); + const driver = CONN_DRIVERS[driverKind]; + return driver.isHibernatable( + this.#actor, + this, + (this.__driverState as any)[driverKind], + ); + } else { + return false; + } + } + /** * Timestamp of the last time the connection was seen, i.e. the last time the connection was active and checked for liveness. */ diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance.ts index a226028fd2..6e70237073 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance.ts @@ -15,6 +15,7 @@ import { PERSISTED_ACTOR_VERSIONED } from "@/schemas/actor-persist/versioned"; import type * as protocol from "@/schemas/client-protocol/mod"; import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; import { + arrayBuffersEqual, bufferToArrayBuffer, getEnvUniversal, promiseWithResolvers, @@ -45,6 +46,7 @@ import { serializeActorKey } from "./keys"; import type { PersistedActor, PersistedConn, + PersistedHibernatableWebSocket, PersistedScheduleEvent, } from "./persisted"; import { processMessage } from "./protocol/old"; @@ -52,6 +54,8 @@ import { CachedSerializer } from "./protocol/serde"; import { Schedule } from "./schedule"; import { DeadlineError, deadline } from "./utils"; +export const PERSIST_SYMBOL = Symbol("persist"); + /** * Options for the `_saveState` method. */ @@ -158,6 +162,10 @@ export class ActorInstance { */ #persist!: PersistedActor; + get [PERSIST_SYMBOL](): PersistedActor { + return this.#persist; + } + /** Raw state without the proxy wrapper */ #persistRaw!: PersistedActor; @@ -1534,17 +1542,116 @@ export class ActorInstance { this.#activeRawWebSockets.add(websocket); this.#resetSleepTimer(); - // Track socket close - const onSocketClosed = () => { + // Track hibernatable WebSockets + let rivetRequestId: ArrayBuffer | undefined; + let persistedHibernatableWebSocket: + | PersistedHibernatableWebSocket + | undefined; + + const onSocketOpened = (event: any) => { + rivetRequestId = event?.rivetRequestId; + + // Find hibernatable WS + if (rivetRequestId) { + const rivetRequestIdLocal = rivetRequestId; + persistedHibernatableWebSocket = + this.#persist.hibernatableWebSocket.find((ws) => + arrayBuffersEqual( + ws.requestId, + rivetRequestIdLocal, + ), + ); + + if (persistedHibernatableWebSocket) { + persistedHibernatableWebSocket.lastSeenTimestamp = + BigInt(Date.now()); + } + } + + this.#rLog.debug({ + msg: "actor instance onSocketOpened", + rivetRequestId, + isHibernatable: !!persistedHibernatableWebSocket, + hibernationMsgIndex: + persistedHibernatableWebSocket?.msgIndex, + }); + }; + + const onSocketMessage = (event: any) => { + // Update state of hibernatable WS + if (persistedHibernatableWebSocket) { + persistedHibernatableWebSocket.lastSeenTimestamp = BigInt( + Date.now(), + ); + persistedHibernatableWebSocket.msgIndex = BigInt( + event.rivetMessageIndex, + ); + } + + this.#rLog.debug({ + msg: "actor instance onSocketMessage", + rivetRequestId, + isHibernatable: !!persistedHibernatableWebSocket, + hibernationMsgIndex: + persistedHibernatableWebSocket?.msgIndex, + }); + }; + + const onSocketClosed = (_event: any) => { + // Remove hibernatable WS + if (rivetRequestId) { + const rivetRequestIdLocal = rivetRequestId; + const wsIndex = + this.#persist.hibernatableWebSocket.findIndex((ws) => + arrayBuffersEqual( + ws.requestId, + rivetRequestIdLocal, + ), + ); + + const removed = this.#persist.hibernatableWebSocket.splice( + wsIndex, + 1, + ); + if (removed.length > 0) { + this.#rLog.debug({ + msg: "removed hibernatable websocket", + rivetRequestId, + hibernationMsgIndex: + persistedHibernatableWebSocket?.msgIndex, + }); + } else { + this.#rLog.warn({ + msg: "could not find hibernatable websocket to remove", + rivetRequestId, + hibernationMsgIndex: + persistedHibernatableWebSocket?.msgIndex, + }); + } + } + + this.#rLog.debug({ + msg: "actor instance onSocketMessage", + rivetRequestId, + isHibernatable: !!persistedHibernatableWebSocket, + hibernatableWebSocketCount: + this.#persist.hibernatableWebSocket.length, + }); + // Remove listener and socket from tracking try { + websocket.removeEventListener("open", onSocketOpened); + websocket.removeEventListener("message", onSocketMessage); websocket.removeEventListener("close", onSocketClosed); websocket.removeEventListener("error", onSocketClosed); } catch {} this.#activeRawWebSockets.delete(websocket); this.#resetSleepTimer(); }; + try { + websocket.addEventListener("open", onSocketOpened); + websocket.addEventListener("message", onSocketMessage); websocket.addEventListener("close", onSocketClosed); websocket.addEventListener("error", onSocketClosed); } catch {} @@ -1794,6 +1901,10 @@ export class ActorInstance { // Check for active conns. This will also cover active actions, since all actions have a connection. for (const conn of this.#connections.values()) { + // TODO: Enable this when hibernation is implemented. We're waiting on support for Guard to not auto-wake the actor if it sleeps. + // if (conn.status === "connected" && !conn.isHibernatable) + // return false; + if (conn.status === "connected") return false; } @@ -1980,6 +2091,11 @@ export class ActorInstance { }, }, })), + hibernatableWebSocket: persist.hibernatableWebSocket.map((ws) => ({ + requestId: ws.requestId, + lastSeenTimestamp: ws.lastSeenTimestamp, + msgIndex: ws.msgIndex, + })), }; } @@ -2012,6 +2128,11 @@ export class ActorInstance { }, }, })), + hibernatableWebSocket: bareData.hibernatableWebSocket.map((ws) => ({ + requestId: ws.requestId, + lastSeenTimestamp: ws.lastSeenTimestamp, + msgIndex: ws.msgIndex, + })), }; } } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/persisted.ts b/rivetkit-typescript/packages/rivetkit/src/actor/persisted.ts index 40bfb66ffe..fb2203e8a1 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/persisted.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/persisted.ts @@ -5,6 +5,7 @@ export interface PersistedActor { state: S; connections: PersistedConn[]; scheduledEvents: PersistedScheduleEvent[]; + hibernatableWebSocket: PersistedHibernatableWebSocket[]; } /** Object representing connection that gets persisted to storage. */ @@ -37,3 +38,9 @@ export interface PersistedScheduleEvent { timestamp: number; kind: PersistedScheduleEventKind; } + +export interface PersistedHibernatableWebSocket { + requestId: ArrayBuffer; + lastSeenTimestamp: bigint; + msgIndex: bigint; +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts index 64f535f237..d8bb7128c0 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts @@ -12,7 +12,7 @@ import { } from "@/actor/conn"; import { ConnDriverKind } from "@/actor/conn-drivers"; import * as errors from "@/actor/errors"; -import type { AnyActorInstance } from "@/actor/instance"; +import { type AnyActorInstance, PERSIST_SYMBOL } from "@/actor/instance"; import type { InputData } from "@/actor/protocol/serde"; import { type Encoding, EncodingSchema } from "@/actor/protocol/serde"; import { @@ -38,7 +38,11 @@ import { deserializeWithEncoding, serializeWithEncoding, } from "@/serde"; -import { bufferToArrayBuffer, promiseWithResolvers } from "@/utils"; +import { + arrayBuffersEqual, + bufferToArrayBuffer, + promiseWithResolvers, +} from "@/utils"; import type { ActorDriver } from "./driver"; import { loggerWithoutContext } from "./log"; import { parseMessage } from "./protocol/old"; @@ -595,38 +599,38 @@ export async function handleRawWebSocketHandler( // Return WebSocket event handlers return { - onOpen: (_evt: any, ws: any) => { + onOpen: (evt: any, ws: any) => { + // Extract rivetRequestId provided by engine runner + const rivetRequestId = evt?.rivetRequestId; + const isHibernatable = + actor[PERSIST_SYMBOL].hibernatableWebSocket.findIndex((ws) => + arrayBuffersEqual(ws.requestId, rivetRequestId), + ) !== -1; + // Wrap the Hono WebSocket in our adapter - const adapter = new HonoWebSocketAdapter(ws); + const adapter = new HonoWebSocketAdapter( + ws, + rivetRequestId, + isHibernatable, + ); // Store adapter reference on the WebSocket for event handlers (ws as any).__adapter = adapter; - // Extract the path after prefix and preserve query parameters - // Use URL API for cleaner parsing - const url = new URL(path, "http://actor"); - const pathname = - url.pathname.replace(/^\/raw\/websocket\/?/, "") || "/"; - const normalizedPath = - (pathname.startsWith("/") ? pathname : "/" + pathname) + - url.search; - + const newPath = truncateRawWebSocketPathPrefix(path); let newRequest: Request; if (req) { - newRequest = new Request(`http://actor${normalizedPath}`, req); + newRequest = new Request(`http://actor${newPath}`, req); } else { - newRequest = new Request(`http://actor${normalizedPath}`, { + newRequest = new Request(`http://actor${newPath}`, { method: "GET", }); } actor.rLog.debug({ msg: "rewriting websocket url", - from: path, - to: newRequest.url, - pathname: url.pathname, - search: url.search, - normalizedPath, + fromPath: path, + toUrl: newRequest.url, }); // Call the actor's onWebSocket handler with the adapted WebSocket @@ -711,3 +715,22 @@ export function getRequestConnParams(req: HonoRequest): unknown { ); } } + +/** + * Truncase the PATH_RAW_WEBSOCKET_PREFIX path prefix in order to pass a clean + * path to the onWebSocket handler. + * + * Example: + * - `/raw/websocket/foo` -> `/foo` + * - `/raw/websocket` -> `/` + */ +export function truncateRawWebSocketPathPrefix(path: string): string { + // Extract the path after prefix and preserve query parameters + // Use URL API for cleaner parsing + const url = new URL(path, "http://actor"); + const pathname = url.pathname.replace(/^\/raw\/websocket\/?/, "") || "/"; + const normalizedPath = + (pathname.startsWith("/") ? pathname : "/" + pathname) + url.search; + + return normalizedPath; +} diff --git a/rivetkit-typescript/packages/rivetkit/src/common/websocket-interface.ts b/rivetkit-typescript/packages/rivetkit/src/common/websocket-interface.ts index 0d970fd803..47912ce44c 100644 --- a/rivetkit-typescript/packages/rivetkit/src/common/websocket-interface.ts +++ b/rivetkit-typescript/packages/rivetkit/src/common/websocket-interface.ts @@ -3,10 +3,20 @@ export interface RivetEvent { type: string; target?: any; currentTarget?: any; + /** + * @experimental + * Request ID for hibernatable websockets (provided by engine runner) + **/ + rivetRequestId?: ArrayBuffer; } export interface RivetMessageEvent extends RivetEvent { data: any; + /** + * @experimental + * Message index for hibernatable websockets (provided by engine runner) + **/ + rivetMessageIndex?: number; } export interface RivetCloseEvent extends RivetEvent { diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-helpers/utils.ts b/rivetkit-typescript/packages/rivetkit/src/driver-helpers/utils.ts index 8581cac70d..a75efdf0e4 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-helpers/utils.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-helpers/utils.ts @@ -15,6 +15,7 @@ export function serializeEmptyPersistData( state: bufferToArrayBuffer(cbor.encode(undefined)), connections: [], scheduledEvents: [], + hibernatableWebSocket: [], }; return PERSISTED_ACTOR_VERSIONED.serializeWithEmbeddedVersion(persistData); } diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts index 06b84689ff..508f0566ce 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts @@ -1,6 +1,7 @@ import type { ActorConfig as EngineActorConfig, RunnerConfig as EngineRunnerConfig, + HibernationConfig, } from "@rivetkit/engine-runner"; import { Runner } from "@rivetkit/engine-runner"; import * as cbor from "cbor-x"; @@ -9,12 +10,14 @@ import { streamSSE } from "hono/streaming"; import { WSContext } from "hono/ws"; import invariant from "invariant"; import { lookupInRegistry } from "@/actor/definition"; +import { PERSIST_SYMBOL } from "@/actor/instance"; import { deserializeActorKey } from "@/actor/keys"; import { EncodingSchema } from "@/actor/protocol/serde"; import { type ActorRouter, createActorRouter } from "@/actor/router"; import { handleRawWebSocketHandler, handleWebSocketConnect, + truncateRawWebSocketPathPrefix, } from "@/actor/router-endpoints"; import type { Client } from "@/client/client"; import { @@ -37,6 +40,7 @@ import { buildActorNames, type RegistryConfig } from "@/registry/config"; import type { RunnerConfig } from "@/registry/run-config"; import { getEndpoint } from "@/remote-manager-driver/api-utils"; import { + arrayBuffersEqual, type LongTimeoutHandle, promiseWithResolvers, setLongTimeout, @@ -137,6 +141,130 @@ export class EngineActorDriver implements ActorDriver { onActorStart: this.#runnerOnActorStart.bind(this), onActorStop: this.#runnerOnActorStop.bind(this), logger: getLogger("engine-runner"), + getActorHibernationConfig: ( + actorId: string, + requestId: ArrayBuffer, + request: Request, + ): HibernationConfig => { + const url = new URL(request.url); + const path = url.pathname; + + // Get actor instance from runner to access actor name + const actorInstance = this.#runner.getActor(actorId); + if (!actorInstance) { + logger().warn({ + msg: "actor not found in getActorHibernationConfig", + actorId, + }); + return { enabled: false, lastMsgIndex: undefined }; + } + + // Load actor handler to access persisted data + const handler = this.#actors.get(actorId); + if (!handler) { + logger().warn({ + msg: "actor handler not found in getActorHibernationConfig", + actorId, + }); + return { enabled: false, lastMsgIndex: undefined }; + } + if (!handler.actor) { + logger().warn({ + msg: "actor not found in getActorHibernationConfig", + actorId, + }); + return { enabled: false, lastMsgIndex: undefined }; + } + + // Check for existing WS + const existingWs = handler.actor[ + PERSIST_SYMBOL + ].hibernatableWebSocket.find((ws) => + arrayBuffersEqual(ws.requestId, requestId), + ); + + // Determine configuration for new WS + let hibernationConfig: HibernationConfig; + if (existingWs) { + hibernationConfig = { + enabled: true, + lastMsgIndex: Number(existingWs.msgIndex), + }; + } else { + if (path === PATH_CONNECT_WEBSOCKET) { + hibernationConfig = { + enabled: true, + lastMsgIndex: undefined, + }; + } else if (path.startsWith(PATH_RAW_WEBSOCKET_PREFIX)) { + // Find actor config + const definition = lookupInRegistry( + this.#registryConfig, + actorInstance.config.name, + ); + + // Check if can hibernate + const canHibernatWebSocket = + definition.config.options?.canHibernatWebSocket; + if (canHibernatWebSocket === true) { + hibernationConfig = { + enabled: true, + lastMsgIndex: undefined, + }; + } else if (typeof canHibernatWebSocket === "function") { + try { + // Truncate the path to match the behavior on onRawWebSocket + const newPath = truncateRawWebSocketPathPrefix( + url.pathname, + ); + const truncatedRequest = new Request( + `http://actor${newPath}`, + request, + ); + + const canHibernate = + canHibernatWebSocket(truncatedRequest); + hibernationConfig = { + enabled: canHibernate, + lastMsgIndex: undefined, + }; + } catch (error) { + logger().error({ + msg: "error calling canHibernatWebSocket", + error, + }); + hibernationConfig = { + enabled: false, + lastMsgIndex: undefined, + }; + } + } else { + hibernationConfig = { + enabled: false, + lastMsgIndex: undefined, + }; + } + } else { + logger().warn({ + msg: "unexpected path for getActorHibernationConfig", + path, + }); + hibernationConfig = { + enabled: false, + lastMsgIndex: undefined, + }; + } + } + + // Save hibernatable WebSocket + handler.actor[PERSIST_SYMBOL].hibernatableWebSocket.push({ + requestId, + lastSeenTimestamp: BigInt(Date.now()), + msgIndex: -1n, + }); + + return hibernationConfig; + }, }; // Create and start runner diff --git a/rivetkit-typescript/packages/rivetkit/src/manager/hono-websocket-adapter.ts b/rivetkit-typescript/packages/rivetkit/src/manager/hono-websocket-adapter.ts index ac6218e80b..6ceab6651d 100644 --- a/rivetkit-typescript/packages/rivetkit/src/manager/hono-websocket-adapter.ts +++ b/rivetkit-typescript/packages/rivetkit/src/manager/hono-websocket-adapter.ts @@ -23,16 +23,28 @@ export class HonoWebSocketAdapter implements UniversalWebSocket { #eventListeners: Map void>> = new Map(); #closeCode?: number; #closeReason?: string; - - constructor(ws: WSContext) { + readonly rivetRequestId?: ArrayBuffer; + readonly isHibernatable: boolean; + + constructor( + ws: WSContext, + rivetRequestId: ArrayBuffer | undefined, + isHibernatable: boolean, + ) { this.#ws = ws; + this.rivetRequestId = rivetRequestId; + this.isHibernatable = isHibernatable; // The WSContext is already open when we receive it this.#readyState = this.OPEN; // Immediately fire the open event setTimeout(() => { - this.#fireEvent("open", { type: "open", target: this }); + this.#fireEvent("open", { + type: "open", + target: this, + rivetRequestId: this.rivetRequestId, + }); }, 0); } @@ -155,6 +167,7 @@ export class HonoWebSocketAdapter implements UniversalWebSocket { code, reason, wasClean: code === 1000, + rivetRequestId: this.rivetRequestId, }); } catch (error) { logger().error({ msg: "error closing websocket", error }); @@ -165,6 +178,7 @@ export class HonoWebSocketAdapter implements UniversalWebSocket { code: 1006, reason: "Abnormal closure", wasClean: false, + rivetRequestId: this.rivetRequestId, }); } } @@ -204,6 +218,8 @@ export class HonoWebSocketAdapter implements UniversalWebSocket { _handleMessage(data: any): void { // Hono may pass either raw data or a MessageEvent-like object let messageData: string | ArrayBuffer | ArrayBufferView; + let rivetRequestId: ArrayBuffer | undefined; + let rivetMessageIndex: number | undefined; if (typeof data === "string") { messageData = data; @@ -212,6 +228,14 @@ export class HonoWebSocketAdapter implements UniversalWebSocket { } else if (data && typeof data === "object" && "data" in data) { // Handle MessageEvent-like objects messageData = data.data; + + // Preserve hibernation-related properties from engine runner + if ("rivetRequestId" in data) { + rivetRequestId = data.rivetRequestId; + } + if ("rivetMessageIndex" in data) { + rivetMessageIndex = data.rivetMessageIndex; + } } else { // Fallback - shouldn't happen in normal operation messageData = String(data); @@ -222,12 +246,15 @@ export class HonoWebSocketAdapter implements UniversalWebSocket { dataType: typeof messageData, isArrayBuffer: messageData instanceof ArrayBuffer, dataStr: typeof messageData === "string" ? messageData : "", + rivetMessageIndex, }); this.#fireEvent("message", { type: "message", target: this, data: messageData, + rivetRequestId, + rivetMessageIndex, }); } @@ -249,6 +276,7 @@ export class HonoWebSocketAdapter implements UniversalWebSocket { code, reason, wasClean: code === 1000, + rivetRequestId: this.rivetRequestId, }); } diff --git a/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/mod.ts b/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/mod.ts index 4e67d40235..e1afda50e2 100644 --- a/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/mod.ts @@ -1 +1 @@ -export * from "../../../dist/schemas/actor-persist/v1"; +export * from "../../../dist/schemas/actor-persist/v2"; diff --git a/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/versioned.ts b/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/versioned.ts index 98c2c37cb7..b6eaff8b3c 100644 --- a/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/versioned.ts +++ b/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/versioned.ts @@ -2,24 +2,36 @@ import { createVersionedDataHandler, type MigrationFn, } from "@/common/versioned-data"; -import * as v1 from "../../../dist/schemas/actor-persist/v1"; +import type * as v1 from "../../../dist/schemas/actor-persist/v1"; +import * as v2 from "../../../dist/schemas/actor-persist/v2"; -export const CURRENT_VERSION = 1; +export const CURRENT_VERSION = 2; -export type CurrentPersistedActor = v1.PersistedActor; -export type CurrentPersistedConnection = v1.PersistedConnection; -export type CurrentPersistedSubscription = v1.PersistedSubscription; +export type CurrentPersistedActor = v2.PersistedActor; +export type CurrentPersistedConnection = v2.PersistedConnection; +export type CurrentPersistedSubscription = v2.PersistedSubscription; export type CurrentGenericPersistedScheduleEvent = - v1.GenericPersistedScheduleEvent; -export type CurrentPersistedScheduleEventKind = v1.PersistedScheduleEventKind; -export type CurrentPersistedScheduleEvent = v1.PersistedScheduleEvent; + v2.GenericPersistedScheduleEvent; +export type CurrentPersistedScheduleEventKind = v2.PersistedScheduleEventKind; +export type CurrentPersistedScheduleEvent = v2.PersistedScheduleEvent; +export type CurrentPersistedHibernatableWebSocket = + v2.PersistedHibernatableWebSocket; const migrations = new Map>(); +// Migration from v1 to v2: Add hibernatableWebSocket field +migrations.set( + 1, + (v1Data: v1.PersistedActor): v2.PersistedActor => ({ + ...v1Data, + hibernatableWebSocket: [], + }), +); + export const PERSISTED_ACTOR_VERSIONED = createVersionedDataHandler({ currentVersion: CURRENT_VERSION, migrations, - serializeVersion: (data) => v1.encodePersistedActor(data), - deserializeVersion: (bytes) => v1.decodePersistedActor(bytes), + serializeVersion: (data) => v2.encodePersistedActor(data), + deserializeVersion: (bytes) => v2.decodePersistedActor(bytes), }); diff --git a/rivetkit-typescript/packages/rivetkit/src/utils.ts b/rivetkit-typescript/packages/rivetkit/src/utils.ts index 6de0a0a168..6f788637bb 100644 --- a/rivetkit-typescript/packages/rivetkit/src/utils.ts +++ b/rivetkit-typescript/packages/rivetkit/src/utils.ts @@ -248,3 +248,18 @@ export function combineUrlPath( const fullQuery = queryParts.length > 0 ? `?${queryParts.join("&")}` : ""; return `${baseUrl.protocol}//${baseUrl.host}${fullPath}${fullQuery}`; } + +export function arrayBuffersEqual( + buf1: ArrayBuffer, + buf2: ArrayBuffer, +): boolean { + if (buf1.byteLength !== buf2.byteLength) return false; + + const view1 = new Uint8Array(buf1); + const view2 = new Uint8Array(buf2); + + for (let i = 0; i < view1.length; i++) { + if (view1[i] !== view2[i]) return false; + } + return true; +}