Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion engine/sdks/typescript/runner/src/mod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ export interface RunnerConfig {
config: ActorConfig,
) => Promise<void>;
onActorStop: (actorId: string, generation: number) => Promise<void>;
getActorHibernationConfig: (actorId: string, requestId: ArrayBuffer) => HibernationConfig;
getActorHibernationConfig: (actorId: string, requestId: ArrayBuffer, request: Request) => HibernationConfig;
noAutoShutdown?: boolean;
}

Expand Down
35 changes: 19 additions & 16 deletions engine/sdks/typescript/runner/src/tunnel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -541,21 +541,10 @@ export class Tunnel {
// Store adapter
this.#actorWebSockets.set(webSocketId, adapter);

// Send open confirmation
let hibernationConfig = this.#runner.config.getActorHibernationConfig(actor.actorId, requestId);
this.#sendMessage(requestId, {
tag: "ToServerWebSocketOpen",
val: {
canHibernate: hibernationConfig.enabled,
lastMsgIndex: BigInt(hibernationConfig.lastMsgIndex ?? -1),
},
});

// Notify adapter that connection is open
adapter._handleOpen(requestId);

// Create a minimal request object for the websocket handler
// Include original headers from the open message
// Convert headers to map
//
// We need to manually ensure the original Upgrade/Connection WS
// headers are present
const headerInit: Record<string, string> = {};
if (open.headers) {
for (const [k, v] of open.headers as ReadonlyMap<
Expand All @@ -565,7 +554,6 @@ export class Tunnel {
headerInit[k] = v;
}
}
// Ensure websocket upgrade headers are present
headerInit["Upgrade"] = "websocket";
headerInit["Connection"] = "Upgrade";

Expand All @@ -574,6 +562,21 @@ export class Tunnel {
headers: headerInit,
});

// Send open confirmation
let hibernationConfig = this.#runner.config.getActorHibernationConfig(actor.actorId, requestId, request);
this.#sendMessage(requestId, {
tag: "ToServerWebSocketOpen",
val: {
canHibernate: hibernationConfig.enabled,
lastMsgIndex: BigInt(hibernationConfig.lastMsgIndex ?? -1),
},
});

// Notify adapter that connection is open
adapter._handleOpen(requestId);



// Call websocket handler
await websocketHandler(
this.#runner,
Expand Down
2 changes: 1 addition & 1 deletion rivetkit-openapi/openapi.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"openapi": "3.0.0",
"info": {
"version": "2.0.21",
"version": "2.0.22-rc.1",
"title": "RivetKit API"
},
"components": {
Expand Down
2 changes: 1 addition & 1 deletion rivetkit-typescript/packages/rivetkit/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@
],
"scripts": {
"build": "tsup src/mod.ts src/client/mod.ts src/common/log.ts src/common/websocket.ts src/actor/errors.ts src/topologies/coordinate/mod.ts src/topologies/partition/mod.ts src/utils.ts src/driver-helpers/mod.ts src/driver-test-suite/mod.ts src/test/mod.ts src/inspector/mod.ts",
"build:schema": "./scripts/compile-bare.ts compile schemas/client-protocol/v1.bare -o dist/schemas/client-protocol/v1.ts && ./scripts/compile-bare.ts compile schemas/file-system-driver/v1.bare -o dist/schemas/file-system-driver/v1.ts && ./scripts/compile-bare.ts compile schemas/actor-persist/v1.bare -o dist/schemas/actor-persist/v1.ts",
"build:schema": "./scripts/compile-bare.ts compile schemas/client-protocol/v1.bare -o dist/schemas/client-protocol/v1.ts && ./scripts/compile-bare.ts compile schemas/file-system-driver/v1.bare -o dist/schemas/file-system-driver/v1.ts && ./scripts/compile-bare.ts compile schemas/actor-persist/v1.bare -o dist/schemas/actor-persist/v1.ts && ./scripts/compile-bare.ts compile schemas/actor-persist/v2.bare -o dist/schemas/actor-persist/v2.ts",
"check-types": "tsc --noEmit",
"test": "vitest run",
"test:watch": "vitest",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# MARK: Connection
# Represents an event subscription.
type PersistedSubscription struct {
# Event name
eventName: str
}

type PersistedConnection struct {
id: str
token: str
parameters: data
state: data
subscriptions: list<PersistedSubscription>
lastSeen: u64
}

# MARK: Schedule Event
type GenericPersistedScheduleEvent struct {
# Action name
action: str
# Arguments for the action
#
# CBOR array
args: optional<data>
}

type PersistedScheduleEventKind union {
GenericPersistedScheduleEvent
}

type PersistedScheduleEvent struct {
eventId: str
timestamp: u64
kind: PersistedScheduleEventKind
}

# MARK: WebSocket
type PersistedHibernatableWebSocket struct {
requestId: data
lastSeenTimestamp: u64
msgIndex: u64
}

# MARK: Actor
# Represents the persisted state of an actor.
type PersistedActor struct {
# Input data passed to the actor on initialization
input: optional<data>
hasInitialized: bool
state: data
connections: list<PersistedConnection>
scheduledEvents: list<PersistedScheduleEvent>
hibernatableWebSocket: list<PersistedHibernatableWebSocket>
}
10 changes: 10 additions & 0 deletions rivetkit-typescript/packages/rivetkit/src/actor/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ export const ActorConfigSchema = z
connectionLivenessInterval: z.number().positive().default(5000),
noSleep: z.boolean().default(false),
sleepTimeout: z.number().positive().default(30_000),
/** @experimental */
canHibernatWebSocket: z
.union([
z.boolean(),
z
.function()
.args(z.custom<Request>())
.returns(z.boolean()),
])
.default(false),
})
.strict()
.default({}),
Expand Down
33 changes: 33 additions & 0 deletions rivetkit-typescript/packages/rivetkit/src/actor/conn-drivers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import type { AnyConn } from "@/actor/conn";
import type { AnyActorInstance } from "@/actor/instance";
import type { CachedSerializer, Encoding } from "@/actor/protocol/serde";
import { encodeDataToString } from "@/actor/protocol/serde";
import type { HonoWebSocketAdapter } from "@/manager/hono-websocket-adapter";
import type * as protocol from "@/schemas/client-protocol/mod";
import { assertUnreachable, type promiseWithResolvers } from "@/utils";

Expand Down Expand Up @@ -67,6 +68,15 @@ export interface ConnDriver<State> {
conn: AnyConn,
state: State,
): ConnReadyState | undefined;

/**
* If the underlying connection can hibernate.
*/
isHibernatable(
actor: AnyActorInstance,
conn: AnyConn,
state: State,
): boolean;
}

// MARK: WebSocket
Expand Down Expand Up @@ -140,6 +150,22 @@ const WEBSOCKET_DRIVER: ConnDriver<ConnDriverWebSocketState> = {
): ConnReadyState | undefined => {
return state.websocket.readyState;
},

isHibernatable(
_actor: AnyActorInstance,
_conn: AnyConn,
state: ConnDriverWebSocketState,
): boolean {
// Extract isHibernatable from the HonoWebSocketAdapter
if (state.websocket.raw) {
const raw = state.websocket.raw as HonoWebSocketAdapter;
if (typeof raw.isHibernatable === "boolean") {
return raw.isHibernatable;
}
}

return false;
},
};

// MARK: SSE
Expand Down Expand Up @@ -175,6 +201,10 @@ const SSE_DRIVER: ConnDriver<ConnDriverSseState> = {

return ConnReadyState.OPEN;
},

isHibernatable(): boolean {
return false;
},
};

// MARK: HTTP
Expand All @@ -187,6 +217,9 @@ const HTTP_DRIVER: ConnDriver<ConnDriverHttpState> = {
// Noop
// TODO: Abort the request
},
isHibernatable(): boolean {
return false;
},
};

/** List of all connection drivers. */
Expand Down
20 changes: 20 additions & 0 deletions rivetkit-typescript/packages/rivetkit/src/actor/conn.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import * as cbor from "cbor-x";
import invariant from "invariant";
import { PersistedHibernatableWebSocket } from "@/schemas/actor-persist/mod";
import type * as protocol from "@/schemas/client-protocol/mod";
import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned";
import { bufferToArrayBuffer } from "@/utils";
Expand Down Expand Up @@ -125,6 +126,25 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
return this.__status;
}

/**
* @experimental
*
* If the underlying connection can hibernate.
*/
public get isHibernatable(): boolean {
if (this.__driverState) {
const driverKind = getConnDriverKindFromState(this.__driverState);
const driver = CONN_DRIVERS[driverKind];
return driver.isHibernatable(
this.#actor,
this,
(this.__driverState as any)[driverKind],
);
} else {
return false;
}
}

/**
* Timestamp of the last time the connection was seen, i.e. the last time the connection was active and checked for liveness.
*/
Expand Down
Loading
Loading