Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 79 additions & 39 deletions engine/sdks/typescript/runner/src/tunnel.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import type * as protocol from "@rivetkit/engine-runner-protocol";
import type { MessageId, RequestId } from "@rivetkit/engine-runner-protocol";
import { v4 as uuidv4, stringify as uuidstringify } from "uuid";
import { stringify as uuidstringify, v4 as uuidv4 } from "uuid";
import { logger } from "./log";
import type { ActorInstance, Runner } from "./mod";
import { unreachable } from "./utils";
Expand Down Expand Up @@ -77,12 +77,20 @@ export class Tunnel {
// Build message
const messageId = generateUuidBuffer();

const requestIdStr = bufferToString(requestId);
this.#pendingTunnelMessages.set(bufferToString(messageId), {
const requestIdStr = idToStr(requestId);
const messageIdStr = idToStr(messageId);
this.#pendingTunnelMessages.set(messageIdStr, {
sentAt: Date.now(),
requestIdStr,
});

logger()?.debug({
msg: "send tunnel msg",
requestId: requestIdStr,
messageId: messageIdStr,
message: messageKind,
});

// Send message
const message: protocol.ToServer = {
tag: "ToServerTunnelMessage",
Expand Down Expand Up @@ -111,8 +119,8 @@ export class Tunnel {

logger()?.debug({
msg: "ack tunnel msg",
requestId: uuidstringify(new Uint8Array(requestId)),
messageId: uuidstringify(new Uint8Array(messageId)),
requestId: idToStr(requestId),
messageId: idToStr(messageId),
});

this.#runner.__sendToServer(message);
Expand Down Expand Up @@ -163,7 +171,10 @@ export class Tunnel {
const webSocket = this.#actorWebSockets.get(requestIdStr);
if (webSocket) {
// Close the WebSocket connection
webSocket.__closeWithRetry(1000, "Message acknowledgment timeout");
webSocket.__closeWithRetry(
1000,
"Message acknowledgment timeout",
);

// Clean up from actorWebSockets map
this.#actorWebSockets.delete(requestIdStr);
Expand Down Expand Up @@ -207,7 +218,11 @@ export class Tunnel {
actor.webSockets.clear();
}

async #fetch(actorId: string, requestId: protocol.RequestId, request: Request): Promise<Response> {
async #fetch(
actorId: string,
requestId: protocol.RequestId,
request: Request,
): Promise<Response> {
// Validate actor exists
if (!this.#runner.hasActor(actorId)) {
logger()?.warn({
Expand All @@ -219,7 +234,10 @@ export class Tunnel {
//
// See should_retry_request_inner
// https://github.com/rivet-dev/rivet/blob/222dae87e3efccaffa2b503de40ecf8afd4e31eb/engine/packages/guard-core/src/proxy_service.rs#L2458
return new Response("Actor not found", { status: 503, headers: { "x-rivet-error": "runner.actor_not_found" } });
return new Response("Actor not found", {
status: 503,
headers: { "x-rivet-error": "runner.actor_not_found" },
});
}

const fetchHandler = this.#runner.config.fetch(
Expand All @@ -237,19 +255,28 @@ export class Tunnel {
}

async handleTunnelMessage(message: protocol.ToClientTunnelMessage) {
const requestIdStr = idToStr(message.requestId);
const messageIdStr = idToStr(message.messageId);
logger()?.debug({
msg: "tunnel msg",
requestId: uuidstringify(new Uint8Array(message.requestId)),
messageId: uuidstringify(new Uint8Array(message.messageId)),
msg: "receive tunnel msg",
requestId: requestIdStr,
messageId: messageIdStr,
message: message.messageKind,
});

if (message.messageKind.tag === "TunnelAck") {
// Mark pending message as acknowledged and remove it
const msgIdStr = bufferToString(message.messageId);
const pending = this.#pendingTunnelMessages.get(msgIdStr);
const pending = this.#pendingTunnelMessages.get(messageIdStr);
if (pending) {
this.#pendingTunnelMessages.delete(msgIdStr);
const didDelete =
this.#pendingTunnelMessages.delete(messageIdStr);
if (!didDelete) {
logger()?.warn({
msg: "received tunnel ack for nonexistent message",
requestId: requestIdStr,
messageId: messageIdStr,
});
}
}
} else {
switch (message.messageKind.tag) {
Expand Down Expand Up @@ -282,14 +309,15 @@ export class Tunnel {
message.messageKind.val,
);
break;
case "ToClientWebSocketMessage":
case "ToClientWebSocketMessage": {
this.#sendAck(message.requestId, message.messageId);

let _unhandled = await this.#handleWebSocketMessage(
const _unhandled = await this.#handleWebSocketMessage(
message.requestId,
message.messageKind.val,
);
break;
}
case "ToClientWebSocketClose":
this.#sendAck(message.requestId, message.messageId);

Expand All @@ -309,7 +337,7 @@ export class Tunnel {
req: protocol.ToClientRequestStart,
) {
// Track this request for the actor
const requestIdStr = bufferToString(requestId);
const requestIdStr = idToStr(requestId);
const actor = this.#runner.getActor(req.actorId);
if (actor) {
actor.requests.add(requestIdStr);
Expand Down Expand Up @@ -342,8 +370,8 @@ export class Tunnel {
existing.actorId = req.actorId;
} else {
this.#actorPendingRequests.set(requestIdStr, {
resolve: () => { },
reject: () => { },
resolve: () => {},
reject: () => {},
streamController: controller,
actorId: req.actorId,
});
Expand All @@ -366,7 +394,11 @@ export class Tunnel {
await this.#sendResponse(requestId, response);
} else {
// Non-streaming request
const response = await this.#fetch(req.actorId, requestId, request);
const response = await this.#fetch(
req.actorId,
requestId,
request,
);
await this.#sendResponse(requestId, response);
}
} catch (error) {
Expand All @@ -385,7 +417,7 @@ export class Tunnel {
requestId: ArrayBuffer,
chunk: protocol.ToClientRequestChunk,
) {
const requestIdStr = bufferToString(requestId);
const requestIdStr = idToStr(requestId);
const pending = this.#actorPendingRequests.get(requestIdStr);
if (pending?.streamController) {
pending.streamController.enqueue(new Uint8Array(chunk.body));
Expand All @@ -397,7 +429,7 @@ export class Tunnel {
}

async #handleRequestAbort(requestId: ArrayBuffer) {
const requestIdStr = bufferToString(requestId);
const requestIdStr = idToStr(requestId);
const pending = this.#actorPendingRequests.get(requestIdStr);
if (pending?.streamController) {
pending.streamController.error(new Error("Request aborted"));
Expand Down Expand Up @@ -461,7 +493,7 @@ export class Tunnel {
requestId: protocol.RequestId,
open: protocol.ToClientWebSocketOpen,
) {
const webSocketId = bufferToString(requestId);
const webSocketId = idToStr(requestId);
// Validate actor exists
const actor = this.#runner.getActor(open.actorId);
if (!actor) {
Expand Down Expand Up @@ -518,7 +550,7 @@ export class Tunnel {
const dataBuffer =
typeof data === "string"
? (new TextEncoder().encode(data)
.buffer as ArrayBuffer)
.buffer as ArrayBuffer)
: data;

this.#sendMessage(requestId, {
Expand Down Expand Up @@ -575,7 +607,12 @@ export class Tunnel {
});

// Send open confirmation
let hibernationConfig = this.#runner.config.getActorHibernationConfig(actor.actorId, requestId, request);
const hibernationConfig =
this.#runner.config.getActorHibernationConfig(
actor.actorId,
requestId,
request,
);
this.#sendMessage(requestId, {
tag: "ToServerWebSocketOpen",
val: {
Expand All @@ -587,8 +624,6 @@ export class Tunnel {
// Notify adapter that connection is open
adapter._handleOpen(requestId);



// Call websocket handler
await websocketHandler(
this.#runner,
Expand Down Expand Up @@ -623,14 +658,19 @@ export class Tunnel {
requestId: ArrayBuffer,
msg: protocol.ToClientWebSocketMessage,
): Promise<boolean> {
const webSocketId = bufferToString(requestId);
const webSocketId = idToStr(requestId);
const adapter = this.#actorWebSockets.get(webSocketId);
if (adapter) {
const data = msg.binary
? new Uint8Array(msg.data)
: new TextDecoder().decode(new Uint8Array(msg.data));

return adapter._handleMessage(requestId, data, msg.index, msg.binary);
return adapter._handleMessage(
requestId,
data,
msg.index,
msg.binary,
);
} else {
return true;
}
Expand All @@ -639,11 +679,12 @@ export class Tunnel {
__ackWebsocketMessage(requestId: ArrayBuffer, index: number) {
logger()?.debug({
msg: "ack ws msg",
requestId: uuidstringify(new Uint8Array(requestId)),
requestId: idToStr(requestId),
index,
});

if (index < 0 || index > 65535) throw new Error("invalid websocket ack index");
if (index < 0 || index > 65535)
throw new Error("invalid websocket ack index");

// Send the ack message
this.#sendMessage(requestId, {
Expand All @@ -658,27 +699,26 @@ export class Tunnel {
requestId: ArrayBuffer,
close: protocol.ToClientWebSocketClose,
) {
const webSocketId = bufferToString(requestId);
const adapter = this.#actorWebSockets.get(webSocketId);
const requestIdStr = idToStr(requestId);
const adapter = this.#actorWebSockets.get(requestIdStr);
if (adapter) {
adapter._handleClose(
requestId,
close.code || undefined,
close.reason || undefined,
);
this.#actorWebSockets.delete(webSocketId);
this.#actorWebSockets.delete(requestIdStr);
}
}
}

/** Converts a buffer to a string. Used for storing strings in a lookup map. */
function bufferToString(buffer: ArrayBuffer): string {
return Buffer.from(buffer).toString("base64");
}

/** Generates a UUID as bytes. */
function generateUuidBuffer(): ArrayBuffer {
const buffer = new Uint8Array(16);
uuidv4(undefined, buffer);
return buffer.buffer;
}

function idToStr(id: ArrayBuffer): string {
return uuidstringify(new Uint8Array(id));
}
Loading
Loading