Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 85 additions & 3 deletions core/llm/llms/Replicate.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ReplicateClient from "replicate";

import { CompletionOptions, LLMOptions } from "../../index.js";
import { ChatMessage, CompletionOptions, LLMOptions } from "../../index.js";
import { renderChatMessage } from "../../util/messageContent.js";
import { BaseLLM } from "../index.js";

class Replicate extends BaseLLM {
Expand Down Expand Up @@ -33,6 +34,7 @@ class Replicate extends BaseLLM {
"tomasmcm/neural-chat-7b-v3-1:acb450496b49e19a1e410b50c574a34acacd54820bc36c19cbfe05148de2ba57",
"deepseek-7b": "kcaverly/deepseek-coder-33b-instruct-gguf" as any,
"phind-codellama-34b": "kcaverly/phind-codellama-34b-v2-gguf" as any,
"claude-4-sonnet-latest": "anthropic/claude-4-sonnet" as any,
};

static providerName = "replicate";
Expand All @@ -42,12 +44,55 @@ class Replicate extends BaseLLM {
options: CompletionOptions,
prompt: string,
signal: AbortSignal,
): [`${string}/${string}:${string}`, { input: any; signal: AbortSignal }] {
): [`${string}/${string}:${string}`, { input: any }] {
return [
Replicate.MODEL_IDS[options.model] || (options.model as any),
{
input: { prompt, message: prompt },
signal,
},
];
}

private _convertChatArgs(
options: CompletionOptions,
messages: ChatMessage[],
signal: AbortSignal,
): [`${string}/${string}:${string}`, { input: any }] {
let prompt = "";
let system_prompt = "";

for (const message of messages) {
const content =
typeof message.content === "string"
? message.content
: renderChatMessage(message);

if (message.role === "system") {
system_prompt += `System: ${content}\n\n`;
} else if (message.role === "user") {
prompt += `Human: ${content}\n\n`;
} else if (message.role === "assistant") {
prompt += `Assistant: ${content}\n\n`;
}
}

if (!prompt.endsWith("Assistant: ")) {
prompt += "Assistant: ";
}

// Construct the input
const input: any = {
prompt: prompt,
system_prompt: system_prompt,
max_tokens: options.maxTokens || 2048,
extended_thinking: options.reasoning,
thinking_budget_tokens: options.reasoningBudgetTokens || 1024,
};

return [
Replicate.MODEL_IDS[options.model] || (options.model as any),
{
input,
},
];
}
Expand Down Expand Up @@ -80,6 +125,43 @@ class Replicate extends BaseLLM {
}
}
}

protected async *_streamChat(
messages: ChatMessage[],
signal: AbortSignal,
options: CompletionOptions,
): AsyncGenerator<ChatMessage> {
if (!this.apiKey || this.apiKey === "") {
throw new Error("You need to use an API key");
}

const [model, args] = this._convertChatArgs(options, messages, signal);

try {
for await (const event of this._replicate.stream(model, args)) {
if (event.event === "output") {
yield {
role: "assistant",
content: event.data,
};
}
}
} catch (error) {
if (error instanceof Error) {
if (error.message.includes("authentication")) {
throw new Error(
"Replicate API authentication failed. Please check your API key",
);
}
if (error.message.includes("model not found")) {
throw new Error(
`Model "${options.model}" not found on Replicate. Please check the model name or use another model.`,
);
}
}
throw error;
}
}
}

export default Replicate;
2 changes: 1 addition & 1 deletion gui/src/pages/AddNewModel/configs/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1155,7 +1155,7 @@ export const models: { [key: string]: ModelPackage } = {
title: "Claude 4 Sonnet",
apiKey: "",
},
providerOptions: ["anthropic"],
providerOptions: ["anthropic", "replicate"],
icon: "anthropic.png",
isOpenSource: false,
},
Expand Down
1 change: 1 addition & 0 deletions gui/src/pages/AddNewModel/configs/providers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,7 @@ Select the \`GPT-4o\` model below to complete your provider configuration, but n
models.codeLlamaInstruct,
models.wizardCoder,
models.mistralOs,
models.claude4Sonnet,
],
apiKeyUrl: "https://replicate.com/account/api-tokens",
},
Expand Down
Loading