diff --git a/core/llm/llms/Replicate.ts b/core/llm/llms/Replicate.ts index 44a15c61d51..ba4b44fb605 100644 --- a/core/llm/llms/Replicate.ts +++ b/core/llm/llms/Replicate.ts @@ -1,6 +1,7 @@ import ReplicateClient from "replicate"; -import { CompletionOptions, LLMOptions } from "../../index.js"; +import { ChatMessage, CompletionOptions, LLMOptions } from "../../index.js"; +import { renderChatMessage } from "../../util/messageContent.js"; import { BaseLLM } from "../index.js"; class Replicate extends BaseLLM { @@ -33,6 +34,7 @@ class Replicate extends BaseLLM { "tomasmcm/neural-chat-7b-v3-1:acb450496b49e19a1e410b50c574a34acacd54820bc36c19cbfe05148de2ba57", "deepseek-7b": "kcaverly/deepseek-coder-33b-instruct-gguf" as any, "phind-codellama-34b": "kcaverly/phind-codellama-34b-v2-gguf" as any, + "claude-4-sonnet-latest": "anthropic/claude-4-sonnet" as any, }; static providerName = "replicate"; @@ -42,12 +44,55 @@ class Replicate extends BaseLLM { options: CompletionOptions, prompt: string, signal: AbortSignal, - ): [`${string}/${string}:${string}`, { input: any; signal: AbortSignal }] { + ): [`${string}/${string}:${string}`, { input: any }] { return [ Replicate.MODEL_IDS[options.model] || (options.model as any), { input: { prompt, message: prompt }, - signal, + }, + ]; + } + + private _convertChatArgs( + options: CompletionOptions, + messages: ChatMessage[], + signal: AbortSignal, + ): [`${string}/${string}:${string}`, { input: any }] { + let prompt = ""; + let system_prompt = ""; + + for (const message of messages) { + const content = + typeof message.content === "string" + ? message.content + : renderChatMessage(message); + + if (message.role === "system") { + system_prompt += `System: ${content}\n\n`; + } else if (message.role === "user") { + prompt += `Human: ${content}\n\n`; + } else if (message.role === "assistant") { + prompt += `Assistant: ${content}\n\n`; + } + } + + if (!prompt.endsWith("Assistant: ")) { + prompt += "Assistant: "; + } + + // Construct the input + const input: any = { + prompt: prompt, + system_prompt: system_prompt, + max_tokens: options.maxTokens || 2048, + extended_thinking: options.reasoning, + thinking_budget_tokens: options.reasoningBudgetTokens || 1024, + }; + + return [ + Replicate.MODEL_IDS[options.model] || (options.model as any), + { + input, }, ]; } @@ -80,6 +125,43 @@ class Replicate extends BaseLLM { } } } + + protected async *_streamChat( + messages: ChatMessage[], + signal: AbortSignal, + options: CompletionOptions, + ): AsyncGenerator { + if (!this.apiKey || this.apiKey === "") { + throw new Error("You need to use an API key"); + } + + const [model, args] = this._convertChatArgs(options, messages, signal); + + try { + for await (const event of this._replicate.stream(model, args)) { + if (event.event === "output") { + yield { + role: "assistant", + content: event.data, + }; + } + } + } catch (error) { + if (error instanceof Error) { + if (error.message.includes("authentication")) { + throw new Error( + "Replicate API authentication failed. Please check your API key", + ); + } + if (error.message.includes("model not found")) { + throw new Error( + `Model "${options.model}" not found on Replicate. Please check the model name or use another model.`, + ); + } + } + throw error; + } + } } export default Replicate; diff --git a/gui/src/pages/AddNewModel/configs/models.ts b/gui/src/pages/AddNewModel/configs/models.ts index aaa8092f749..07cac28fa2f 100644 --- a/gui/src/pages/AddNewModel/configs/models.ts +++ b/gui/src/pages/AddNewModel/configs/models.ts @@ -1155,7 +1155,7 @@ export const models: { [key: string]: ModelPackage } = { title: "Claude 4 Sonnet", apiKey: "", }, - providerOptions: ["anthropic"], + providerOptions: ["anthropic", "replicate"], icon: "anthropic.png", isOpenSource: false, }, diff --git a/gui/src/pages/AddNewModel/configs/providers.ts b/gui/src/pages/AddNewModel/configs/providers.ts index 88e091ce775..e5e8c1d7f86 100644 --- a/gui/src/pages/AddNewModel/configs/providers.ts +++ b/gui/src/pages/AddNewModel/configs/providers.ts @@ -725,6 +725,7 @@ Select the \`GPT-4o\` model below to complete your provider configuration, but n models.codeLlamaInstruct, models.wizardCoder, models.mistralOs, + models.claude4Sonnet, ], apiKeyUrl: "https://replicate.com/account/api-tokens", },