diff --git a/Dockerfile b/Dockerfile index 88599e47..ab277b95 100644 --- a/Dockerfile +++ b/Dockerfile @@ -33,6 +33,9 @@ COPY --from=builder /app/config.json ./ COPY --from=builder /app/dist ./dist # Set environment variables (Recommended to set at runtime, avoid hardcoding) +ENV AZURE_OPENAI_RESOURCE_NAME=${AZURE_OPENAI_RESOURCE_NAME} +ENV AZURE_OPENAI_API_KEY=${AZURE_OPENAI_API_KEY} +ENV AZURE_OPENAI_API_VERSION=${AZURE_OPENAI_API_VERSION} ENV GEMINI_API_KEY=${GEMINI_API_KEY} ENV OPENAI_API_KEY=${OPENAI_API_KEY} ENV JINA_API_KEY=${JINA_API_KEY} diff --git a/config.json b/config.json index d963fbd8..66c39016 100644 --- a/config.json +++ b/config.json @@ -7,7 +7,10 @@ "JINA_API_KEY": "", "BRAVE_API_KEY": "", "SERPER_API_KEY": "", - "DEFAULT_MODEL_NAME": "" + "DEFAULT_MODEL_NAME": "", + "AZURE_OPENAI_RESOURCE_NAME": "", + "AZURE_OPENAI_API_KEY": "", + "AZURE_OPENAI_API_VERSION": "" }, "defaults": { "search_provider": "jina", @@ -23,6 +26,9 @@ "clientConfig": { "compatibility": "strict" } + }, + "azure":{ + "createClient": "createAzure" } }, "models": { @@ -61,6 +67,25 @@ "agentBeastMode": { "temperature": 0.7 }, "fallback": { "temperature": 0 } } + }, + "azure":{ + "default": { + "model": "gpt-4o", + "temperature": 0, + "maxTokens": 10000 + }, + "tools": { + "coder": { "temperature": 0.7 }, + "searchGrounding": { "temperature": 0 }, + "dedup": { "temperature": 0.1 }, + "evaluator": {}, + "errorAnalyzer": {}, + "queryRewriter": { "temperature": 0.1 }, + "agent": { "temperature": 0.7 }, + "agentBeastMode": { "temperature": 0.7 }, + "fallback": { "temperature": 0 } + } + } } } diff --git a/docker-compose.yml b/docker-compose.yml index 9928a990..c578d308 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -6,6 +6,9 @@ services: context: . dockerfile: Dockerfile environment: + - AZURE_OPENAI_RESOURCE_NAME=${AZURE_OPENAI_RESOURCE_NAME} + - AZURE_OPENAI_API_KEY=${AZURE_OPENAI_API_KEY} + - AZURE_OPENAI_API_VERSION=${AZURE_OPENAI_API_VERSION} - GEMINI_API_KEY=${GEMINI_API_KEY} - OPENAI_API_KEY=${OPENAI_API_KEY} - JINA_API_KEY=${JINA_API_KEY} diff --git a/package-lock.json b/package-lock.json index 80f440b0..bb3f70c4 100644 --- a/package-lock.json +++ b/package-lock.json @@ -9,6 +9,7 @@ "version": "1.0.0", "license": "Apache-2.0", "dependencies": { + "@ai-sdk/azure": "^1.3.16", "@ai-sdk/google": "^1.0.0", "@ai-sdk/openai": "^1.1.9", "@types/jsdom": "^21.1.7", @@ -65,6 +66,52 @@ "zod": "^3.0.0" } }, + "node_modules/@ai-sdk/azure": { + "version": "1.3.16", + "resolved": "https://registry.npmjs.org/@ai-sdk/azure/-/azure-1.3.16.tgz", + "integrity": "sha512-t40iJ6yep0mdBX8nQUz9EvHFHYmsJEoCbrbAyheOE49dpqPXpiRkIEEbmS8hSclZnYRDylHzyq1nEO2YrhHzMg==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/openai": "1.3.15", + "@ai-sdk/provider": "1.1.3", + "@ai-sdk/provider-utils": "2.2.7" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.0.0" + } + }, + "node_modules/@ai-sdk/azure/node_modules/@ai-sdk/provider": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.3.tgz", + "integrity": "sha512-qZMxYJ0qqX/RfnuIaab+zp8UAeJn/ygXXAffR5I4N0n1IrvA6qBsjc8hXLmBiMV2zoXlifkacF7sEFnYnjBcqg==", + "license": "Apache-2.0", + "dependencies": { + "json-schema": "^0.4.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@ai-sdk/azure/node_modules/@ai-sdk/provider-utils": { + "version": "2.2.7", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.7.tgz", + "integrity": "sha512-kM0xS3GWg3aMChh9zfeM+80vEZfXzR3JEUBdycZLtbRZ2TRT8xOj3WodGHPb06sUK5yD7pAXC/P7ctsi2fvUGQ==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "1.1.3", + "nanoid": "^3.3.8", + "secure-json-parse": "^2.7.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.23.8" + } + }, "node_modules/@ai-sdk/google": { "version": "1.1.11", "resolved": "https://registry.npmjs.org/@ai-sdk/google/-/google-1.1.11.tgz", @@ -102,13 +149,13 @@ } }, "node_modules/@ai-sdk/openai": { - "version": "1.1.9", - "resolved": "https://registry.npmjs.org/@ai-sdk/openai/-/openai-1.1.9.tgz", - "integrity": "sha512-t/CpC4TLipdbgBJTMX/otzzqzCMBSPQwUOkYPGbT/jyuC86F+YO9o+LS0Ty2pGUE1kyT+B3WmJ318B16ZCg4hw==", + "version": "1.3.15", + "resolved": "https://registry.npmjs.org/@ai-sdk/openai/-/openai-1.3.15.tgz", + "integrity": "sha512-sjUM1A+Pwui+fAn4kmBfVyhJGhqIbsqNzERNPmtRKRBhYG7p946+DTmbJ4lsZG+r4+kMn87aDCYyRexnsnbu5g==", "license": "Apache-2.0", "dependencies": { - "@ai-sdk/provider": "1.0.7", - "@ai-sdk/provider-utils": "2.1.6" + "@ai-sdk/provider": "1.1.3", + "@ai-sdk/provider-utils": "2.2.7" }, "engines": { "node": ">=18" @@ -117,6 +164,35 @@ "zod": "^3.0.0" } }, + "node_modules/@ai-sdk/openai/node_modules/@ai-sdk/provider": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.3.tgz", + "integrity": "sha512-qZMxYJ0qqX/RfnuIaab+zp8UAeJn/ygXXAffR5I4N0n1IrvA6qBsjc8hXLmBiMV2zoXlifkacF7sEFnYnjBcqg==", + "license": "Apache-2.0", + "dependencies": { + "json-schema": "^0.4.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@ai-sdk/openai/node_modules/@ai-sdk/provider-utils": { + "version": "2.2.7", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.7.tgz", + "integrity": "sha512-kM0xS3GWg3aMChh9zfeM+80vEZfXzR3JEUBdycZLtbRZ2TRT8xOj3WodGHPb06sUK5yD7pAXC/P7ctsi2fvUGQ==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "1.1.3", + "nanoid": "^3.3.8", + "secure-json-parse": "^2.7.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.23.8" + } + }, "node_modules/@ai-sdk/provider": { "version": "1.0.7", "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.0.7.tgz", diff --git a/package.json b/package.json index 649b5285..90b19b8e 100644 --- a/package.json +++ b/package.json @@ -28,6 +28,7 @@ "dependencies": { "@ai-sdk/google": "^1.0.0", "@ai-sdk/openai": "^1.1.9", + "@ai-sdk/azure": "^1.3.16", "@types/jsdom": "^21.1.7", "ai": "^4.1.26", "axios": "^1.7.9", diff --git a/src/config.ts b/src/config.ts index 76bee3a3..1778548c 100644 --- a/src/config.ts +++ b/src/config.ts @@ -2,12 +2,13 @@ import dotenv from 'dotenv'; import { ProxyAgent, setGlobalDispatcher } from 'undici'; import { createGoogleGenerativeAI } from '@ai-sdk/google'; import { createOpenAI, OpenAIProviderSettings } from '@ai-sdk/openai'; +import { createAzure, AzureOpenAIProviderSettings } from '@ai-sdk/azure' import configJson from '../config.json'; // Load environment variables dotenv.config(); // Types -export type LLMProvider = 'openai' | 'gemini' | 'vertex'; +export type LLMProvider = 'openai' | 'gemini' | 'vertex' | 'azure'; export type ToolName = keyof typeof configJson.models.gemini.tools; // Type definitions for our config structure @@ -46,6 +47,9 @@ export const BRAVE_API_KEY = env.BRAVE_API_KEY; export const SERPER_API_KEY = env.SERPER_API_KEY; export const SEARCH_PROVIDER = configJson.defaults.search_provider; export const STEP_SLEEP = configJson.defaults.step_sleep; +export const AZURE_OPENAI_RESOURCE_NAME = env.AZURE_OPENAI_RESOURCE_NAME; +export const AZURE_OPENAI_API_KEY = env.AZURE_OPENAI_API_KEY; +export const AZURE_OPENAI_API_VERSION = env.AZURE_OPENAI_API_VERSION; // Determine LLM provider export const LLM_PROVIDER: LLMProvider = (() => { @@ -57,7 +61,7 @@ export const LLM_PROVIDER: LLMProvider = (() => { })(); function isValidProvider(provider: string): provider is LLMProvider { - return provider === 'openai' || provider === 'gemini' || provider === 'vertex'; + return provider === 'openai' || provider === 'gemini' || provider === 'vertex' || provider === 'azure'; } interface ToolConfig { @@ -92,6 +96,27 @@ export function getMaxTokens(toolName: ToolName): number { export function getModel(toolName: ToolName) { const config = getToolConfig(toolName); const providerConfig = (configJson.providers as Record)[LLM_PROVIDER]; + if (LLM_PROVIDER === 'azure') { + if (!AZURE_OPENAI_API_KEY) { + throw new Error('AZURE_OPENAI_API_KEY not found'); + } + + if (!AZURE_OPENAI_RESOURCE_NAME) { + throw new Error('AZURE_OPENAI_RESOURCE_NAME not found'); + } + + if (!AZURE_OPENAI_API_VERSION) { + throw new Error('AZURE_OPENAI_API_VERSION not found'); + } + + const opt: AzureOpenAIProviderSettings = { + apiKey: AZURE_OPENAI_API_KEY, + resourceName: AZURE_OPENAI_RESOURCE_NAME, + apiVersion: AZURE_OPENAI_API_VERSION + }; + + return createAzure(opt)(config.model); + } if (LLM_PROVIDER === 'openai') { if (!OPENAI_API_KEY) { @@ -131,29 +156,54 @@ export function getModel(toolName: ToolName) { // Validate required environment variables if (LLM_PROVIDER === 'gemini' && !GEMINI_API_KEY) throw new Error("GEMINI_API_KEY not found"); if (LLM_PROVIDER === 'openai' && !OPENAI_API_KEY) throw new Error("OPENAI_API_KEY not found"); +if (LLM_PROVIDER === 'azure' && !AZURE_OPENAI_API_KEY) throw new Error("AZURE_OPENAI_API_KEY not found"); +if (LLM_PROVIDER === 'azure' && !AZURE_OPENAI_RESOURCE_NAME) throw new Error("AZURE_OPENAI_RESOURCE_NAME not found"); +if (LLM_PROVIDER === 'azure' && !AZURE_OPENAI_API_VERSION) throw new Error("AZURE_OPENAI_API_VERSION not found"); if (!JINA_API_KEY) throw new Error("JINA_API_KEY not found"); -// Log all configurations +const providerModels: Record = { + openai : configJson.models.openai.default.model, + gemini : configJson.models.gemini.default.model, + vertex : configJson.models.gemini.default.model, // vertex uses Gemini models + azure : configJson.models.azure.default.model, +}; + +const providerExtras: Record> = { + openai : { baseUrl: OPENAI_BASE_URL }, + azure : { resourceName: AZURE_OPENAI_RESOURCE_NAME, apiVersion: AZURE_OPENAI_API_VERSION }, + gemini : {}, + vertex : {}, +}; + +const providerNameForTools: Record = { + openai : 'openai', + gemini : 'gemini', + vertex : 'gemini', // vertex shares the Gemini tool settings + azure : 'azure', +}; + const configSummary = { provider: { - name: LLM_PROVIDER, - model: LLM_PROVIDER === 'openai' - ? configJson.models.openai.default.model - : configJson.models.gemini.default.model, - ...(LLM_PROVIDER === 'openai' && { baseUrl: OPENAI_BASE_URL }) - }, - search: { - provider: SEARCH_PROVIDER + name : LLM_PROVIDER, + model: providerModels[LLM_PROVIDER], + ...providerExtras[LLM_PROVIDER], // adds baseUrl / endpoint when present }, + + search: { provider: SEARCH_PROVIDER }, + tools: Object.fromEntries( - Object.keys(configJson.models[LLM_PROVIDER === 'vertex' ? 'gemini' : LLM_PROVIDER].tools).map(name => [ + Object.keys( + configJson.models[providerNameForTools[LLM_PROVIDER]].tools + ).map(name => [ name, - getToolConfig(name as ToolName) - ]) + getToolConfig(name as ToolName), + ]), ), - defaults: { - stepSleep: STEP_SLEEP - } + + defaults: { stepSleep: STEP_SLEEP }, }; -console.log('Configuration Summary:', JSON.stringify(configSummary, null, 2)); +console.log( + 'Configuration Summary:', + JSON.stringify(configSummary, null, 2), +);