From 268ab5010b07a0469b0b1ae67fc81371e449f9ac Mon Sep 17 00:00:00 2001
From: Gustavo Cid Ornelas
Date: Mon, 23 Sep 2024 15:36:47 -0300
Subject: [PATCH] feat: introduce the OpenlayerHandler, which implements the
 LangChain callback handler interface

---
 src/lib/integrations/langchainCallback.ts | 141 ++++++++++++++++++++++
 1 file changed, 141 insertions(+)
 create mode 100644 src/lib/integrations/langchainCallback.ts

diff --git a/src/lib/integrations/langchainCallback.ts b/src/lib/integrations/langchainCallback.ts
new file mode 100644
index 00000000..fbd07ce4
--- /dev/null
+++ b/src/lib/integrations/langchainCallback.ts
@@ -0,0 +1,141 @@
+import { BaseCallbackHandler } from '@langchain/core/callbacks/base';
+import { LLMResult } from '@langchain/core/outputs';
+import type { Serialized } from '@langchain/core/load/serializable';
+import { AIMessage, BaseMessage, SystemMessage } from '@langchain/core/messages';
+import { addChatCompletionStepToTrace } from '../tracing/tracer';
+
+const LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP: Record<string, string> = {
+  openai: 'OpenAI',
+  'openai-chat': 'OpenAI',
+  'chat-ollama': 'Ollama',
+  vertexai: 'Google',
+};
+const PROVIDER_TO_STEP_NAME: Record<string, string> = {
+  OpenAI: 'OpenAI Chat Completion',
+  Ollama: 'Ollama Chat Completion',
+  Google: 'Google Vertex AI Chat Completion',
+};
+
+export class OpenlayerHandler extends BaseCallbackHandler {
+  name = 'OpenlayerHandler';
+  startTime: number | null = null;
+  endTime: number | null = null;
+  prompt: Array<{ role: string; content: string }> | null = null;
+  latency: number = 0;
+  provider: string | undefined;
+  model: string | null = null;
+  modelParameters: Record<string, unknown> | null = null;
+  promptTokens: number | null = 0;
+  completionTokens: number | null = 0;
+  totalTokens: number | null = 0;
+  output: string = '';
+  metadata: Record<string, unknown>;
+
+  constructor(kwargs: Record<string, unknown> = {}) {
+    super();
+    this.metadata = kwargs;
+  }
+  override async handleChatModelStart(
+    llm: Serialized,
+    messages: BaseMessage[][],
+    runId: string,
+    parentRunId?: string | undefined,
+    extraParams?: Record<string, unknown> | undefined,
+    tags?: string[] | undefined,
+    metadata?: Record<string, unknown> | undefined,
+    name?: string,
+  ): Promise<void> {
+    this.initializeRun(extraParams || {}, metadata || {});
+    this.prompt = this.langchainMessagesToPrompt(messages);
+    this.startTime = performance.now();
+  }
+
+  private initializeRun(extraParams: Record<string, unknown>, metadata: Record<string, unknown>): void {
+    this.modelParameters = (extraParams['invocation_params'] as Record<string, unknown>) || {};
+
+    const provider = metadata?.['ls_provider'] as string;
+    if (provider && LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP[provider]) {
+      this.provider = LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP[provider];
+    }
+    this.model = (this.modelParameters?.['model'] as string) || (metadata['ls_model_name'] as string) || null;
+    this.output = '';
+  }
+
+  private langchainMessagesToPrompt(messages: BaseMessage[][]): Array<{ role: string; content: string }> {
+    const prompt: Array<{ role: string; content: string }> = [];
+    for (const message of messages) {
+      for (const m of message) {
+        if (m instanceof AIMessage) {
+          prompt.push({ role: 'assistant', content: m.content as string });
+        } else if (m instanceof SystemMessage) {
+          prompt.push({ role: 'system', content: m.content as string });
+        } else {
+          prompt.push({ role: 'user', content: m.content as string });
+        }
+      }
+    }
+    return prompt;
+  }
+
+  override async handleLLMStart(
+    llm: Serialized,
+    prompts: string[],
+    runId: string,
+    parentRunId?: string,
+    extraParams?: Record<string, unknown>,
+    tags?: string[],
+    metadata?: Record<string, unknown>,
+    runName?: string,
+  ) {
+    this.initializeRun(extraParams || {}, metadata || {});
+    this.prompt = prompts.map((p) => ({ role: 'user', content: p }));
+    this.startTime = performance.now();
+  }
+
+  override async handleLLMEnd(output: LLMResult, runId: string, parentRunId?: string, tags?: string[]) {
+    this.endTime = performance.now();
+    this.latency = this.endTime - this.startTime!;
+    this.extractTokenInformation(output);
+    this.extractOutput(output);
+    this.addToTrace();
+  }
+
+  private extractTokenInformation(output: LLMResult) {
+    if (this.provider === 'OpenAI') {
+      this.openaiTokenInformation(output);
+    }
+  }
+
+  private openaiTokenInformation(output: LLMResult) {
+    if (output.llmOutput && 'tokenUsage' in output.llmOutput) {
+      this.promptTokens = output.llmOutput?.['tokenUsage']?.promptTokens ?? 0;
+      this.completionTokens = output.llmOutput?.['tokenUsage']?.completionTokens ?? 0;
+      this.totalTokens = output.llmOutput?.['tokenUsage']?.totalTokens ?? 0;
+    }
+  }
+
+  private extractOutput(output: LLMResult) {
+    const lastResponse = output?.generations?.at(-1)?.at(-1) ?? undefined;
+    this.output += lastResponse?.text ?? '';
+  }
+
+  private addToTrace() {
+    let name = 'Chat Completion Model';
+    if (this.provider && this.provider in PROVIDER_TO_STEP_NAME) {
+      name = PROVIDER_TO_STEP_NAME[this.provider] ?? 'Chat Completion Model';
+    }
+    addChatCompletionStepToTrace({
+      name: name,
+      inputs: { prompt: this.prompt },
+      output: this.output,
+      latency: this.latency,
+      tokens: this.totalTokens,
+      promptTokens: this.promptTokens,
+      completionTokens: this.completionTokens,
+      model: this.model,
+      modelParameters: this.modelParameters,
+      metadata: this.metadata,
+      provider: this.provider ?? '',
+    });
+  }
+}
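
Usage sketch (not part of the diff): a minimal example of wiring the new
handler into a LangChain call, assuming `@langchain/openai` is installed and
`OPENAI_API_KEY` is set; the model name, metadata, and import path are
illustrative, not prescribed by this patch. LangChain fires the handler's
callbacks around the LLM call, and handleLLMEnd pushes the recorded step to
the Openlayer trace:

    import { ChatOpenAI } from '@langchain/openai';
    // Hypothetical import path; adjust to wherever the package exports it.
    import { OpenlayerHandler } from './lib/integrations/langchainCallback';

    async function main() {
      // Constructor kwargs are attached to the traced step as metadata.
      const handler = new OpenlayerHandler({ environment: 'development' });
      const model = new ChatOpenAI({ model: 'gpt-4o-mini' });

      // Passing the handler via `callbacks` triggers handleChatModelStart
      // before the request and handleLLMEnd after it, capturing the prompt,
      // output, latency, and token counts.
      const response = await model.invoke('What is the capital of France?', {
        callbacks: [handler],
      });
      console.log(response.content);
    }

    main();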